diff options
author | Peter Dettman <peter.dettman@bouncycastle.org> | 2014-10-17 22:17:56 +0700 |
---|---|---|
committer | Peter Dettman <peter.dettman@bouncycastle.org> | 2014-10-17 22:17:56 +0700 |
commit | 9a560818c4b0981dc251ab02dace374560219f1e (patch) | |
tree | 3056a31f2e8374fa33910eaf5167734468a9f967 /crypto | |
parent | Implement draft-bmoeller-tls-downgrade-scsv-02 (diff) | |
download | BouncyCastle.NET-ed25519-9a560818c4b0981dc251ab02dace374560219f1e.tar.xz |
Initial port of DTLS client/server from Java
Diffstat (limited to 'crypto')
-rw-r--r-- | crypto/crypto.csproj | 55 | ||||
-rw-r--r-- | crypto/src/crypto/tls/DatagramTransport.cs | 23 | ||||
-rw-r--r-- | crypto/src/crypto/tls/DtlsClientProtocol.cs | 843 | ||||
-rw-r--r-- | crypto/src/crypto/tls/DtlsEpoch.cs | 51 | ||||
-rw-r--r-- | crypto/src/crypto/tls/DtlsHandshakeRetransmit.cs | 11 | ||||
-rw-r--r-- | crypto/src/crypto/tls/DtlsProtocol.cs | 72 | ||||
-rw-r--r-- | crypto/src/crypto/tls/DtlsReassembler.cs | 125 | ||||
-rw-r--r-- | crypto/src/crypto/tls/DtlsRecordLayer.cs | 507 | ||||
-rw-r--r-- | crypto/src/crypto/tls/DtlsReliableHandshake.cs | 443 | ||||
-rw-r--r-- | crypto/src/crypto/tls/DtlsReplayWindow.cs | 85 | ||||
-rw-r--r-- | crypto/src/crypto/tls/DtlsServerProtocol.cs | 642 | ||||
-rw-r--r-- | crypto/src/crypto/tls/DtlsTransport.cs | 77 |
12 files changed, 2934 insertions, 0 deletions
diff --git a/crypto/crypto.csproj b/crypto/crypto.csproj index 81f74e656..74aac8b6e 100644 --- a/crypto/crypto.csproj +++ b/crypto/crypto.csproj @@ -4434,6 +4434,11 @@ BuildAction = "Compile" /> <File + RelPath = "src\crypto\tls\DatagramTransport.cs" + SubType = "Code" + BuildAction = "Compile" + /> + <File RelPath = "src\crypto\tls\DefaultTlsAgreementCredentials.cs" SubType = "Code" BuildAction = "Compile" @@ -4479,6 +4484,56 @@ BuildAction = "Compile" /> <File + RelPath = "src\crypto\tls\DtlsClientProtocol.cs" + SubType = "Code" + BuildAction = "Compile" + /> + <File + RelPath = "src\crypto\tls\DtlsEpoch.cs" + SubType = "Code" + BuildAction = "Compile" + /> + <File + RelPath = "src\crypto\tls\DtlsHandshakeRetransmit.cs" + SubType = "Code" + BuildAction = "Compile" + /> + <File + RelPath = "src\crypto\tls\DtlsProtocol.cs" + SubType = "Code" + BuildAction = "Compile" + /> + <File + RelPath = "src\crypto\tls\DtlsReassembler.cs" + SubType = "Code" + BuildAction = "Compile" + /> + <File + RelPath = "src\crypto\tls\DtlsRecordLayer.cs" + SubType = "Code" + BuildAction = "Compile" + /> + <File + RelPath = "src\crypto\tls\DtlsReliableHandshake.cs" + SubType = "Code" + BuildAction = "Compile" + /> + <File + RelPath = "src\crypto\tls\DtlsReplayWindow.cs" + SubType = "Code" + BuildAction = "Compile" + /> + <File + RelPath = "src\crypto\tls\DtlsServerProtocol.cs" + SubType = "Code" + BuildAction = "Compile" + /> + <File + RelPath = "src\crypto\tls\DtlsTransport.cs" + SubType = "Code" + BuildAction = "Compile" + /> + <File RelPath = "src\crypto\tls\ECBasisType.cs" SubType = "Code" BuildAction = "Compile" diff --git a/crypto/src/crypto/tls/DatagramTransport.cs b/crypto/src/crypto/tls/DatagramTransport.cs new file mode 100644 index 000000000..524a8b181 --- /dev/null +++ b/crypto/src/crypto/tls/DatagramTransport.cs @@ -0,0 +1,23 @@ +using System; +using System.IO; + +namespace Org.BouncyCastle.Crypto.Tls +{ + public interface DatagramTransport + { + /// <exception cref="IOException"/> + int GetReceiveLimit(); + + /// <exception cref="IOException"/> + int GetSendLimit(); + + /// <exception cref="IOException"/> + int Receive(byte[] buf, int off, int len, int waitMillis); + + /// <exception cref="IOException"/> + void Send(byte[] buf, int off, int len); + + /// <exception cref="IOException"/> + void Close(); + } +} diff --git a/crypto/src/crypto/tls/DtlsClientProtocol.cs b/crypto/src/crypto/tls/DtlsClientProtocol.cs new file mode 100644 index 000000000..ae6ebbce8 --- /dev/null +++ b/crypto/src/crypto/tls/DtlsClientProtocol.cs @@ -0,0 +1,843 @@ +using System; +using System.Collections; +using System.IO; + +using Org.BouncyCastle.Security; +using Org.BouncyCastle.Utilities; + +namespace Org.BouncyCastle.Crypto.Tls +{ + public class DtlsClientProtocol + : DtlsProtocol + { + public DtlsClientProtocol(SecureRandom secureRandom) + : base(secureRandom) + { + } + + public virtual DtlsTransport Connect(TlsClient client, DatagramTransport transport) + { + if (client == null) + throw new ArgumentNullException("client"); + if (transport == null) + throw new ArgumentNullException("transport"); + + SecurityParameters securityParameters = new SecurityParameters(); + securityParameters.entity = ConnectionEnd.client; + + ClientHandshakeState state = new ClientHandshakeState(); + state.client = client; + state.clientContext = new TlsClientContextImpl(mSecureRandom, securityParameters); + + securityParameters.clientRandom = TlsProtocol.CreateRandomBlock(client.ShouldUseGmtUnixTime(), + state.clientContext.NonceRandomGenerator); + + client.Init(state.clientContext); + + DtlsRecordLayer recordLayer = new DtlsRecordLayer(transport, state.clientContext, client, ContentType.handshake); + + TlsSession sessionToResume = state.client.GetSessionToResume(); + if (sessionToResume != null) + { + SessionParameters sessionParameters = sessionToResume.ExportSessionParameters(); + if (sessionParameters != null) + { + state.tlsSession = sessionToResume; + state.sessionParameters = sessionParameters; + } + } + + try + { + return ClientHandshake(state, recordLayer); + } + catch (TlsFatalAlert fatalAlert) + { + recordLayer.Fail(fatalAlert.AlertDescription); + throw fatalAlert; + } + catch (IOException e) + { + recordLayer.Fail(AlertDescription.internal_error); + throw e; + } + catch (Exception e) + { + recordLayer.Fail(AlertDescription.internal_error); + throw new TlsFatalAlert(AlertDescription.internal_error, e); + } + } + + internal virtual DtlsTransport ClientHandshake(ClientHandshakeState state, DtlsRecordLayer recordLayer) + { + SecurityParameters securityParameters = state.clientContext.SecurityParameters; + DtlsReliableHandshake handshake = new DtlsReliableHandshake(state.clientContext, recordLayer); + + byte[] clientHelloBody = GenerateClientHello(state, state.client); + handshake.SendMessage(HandshakeType.client_hello, clientHelloBody); + + DtlsReliableHandshake.Message serverMessage = handshake.ReceiveMessage(); + + while (serverMessage.Type == HandshakeType.hello_verify_request) + { + ProtocolVersion recordLayerVersion = recordLayer.ResetDiscoveredPeerVersion(); + ProtocolVersion client_version = state.clientContext.ClientVersion; + + /* + * RFC 6347 4.2.1 DTLS 1.2 server implementations SHOULD use DTLS version 1.0 regardless of + * the version of TLS that is expected to be negotiated. DTLS 1.2 and 1.0 clients MUST use + * the version solely to indicate packet formatting (which is the same in both DTLS 1.2 and + * 1.0) and not as part of version negotiation. + */ + if (!recordLayerVersion.IsEqualOrEarlierVersionOf(client_version)) + throw new TlsFatalAlert(AlertDescription.illegal_parameter); + + byte[] cookie = ProcessHelloVerifyRequest(state, serverMessage.Body); + byte[] patched = PatchClientHelloWithCookie(clientHelloBody, cookie); + + handshake.ResetHandshakeMessagesDigest(); + handshake.SendMessage(HandshakeType.client_hello, patched); + + serverMessage = handshake.ReceiveMessage(); + } + + if (serverMessage.Type == HandshakeType.server_hello) + { + ReportServerVersion(state, recordLayer.DiscoveredPeerVersion); + + ProcessServerHello(state, serverMessage.Body); + } + else + { + throw new TlsFatalAlert(AlertDescription.unexpected_message); + } + + if (state.maxFragmentLength >= 0) + { + int plainTextLimit = 1 << (8 + state.maxFragmentLength); + recordLayer.SetPlaintextLimit(plainTextLimit); + } + + securityParameters.cipherSuite = state.selectedCipherSuite; + securityParameters.compressionAlgorithm = (byte)state.selectedCompressionMethod; + securityParameters.prfAlgorithm = TlsProtocol.GetPrfAlgorithm(state.clientContext, state.selectedCipherSuite); + + /* + * RFC 5264 7.4.9. Any cipher suite which does not explicitly specify verify_data_length has + * a verify_data_length equal to 12. This includes all existing cipher suites. + */ + securityParameters.verifyDataLength = 12; + + handshake.NotifyHelloComplete(); + + bool resumedSession = state.selectedSessionID.Length > 0 && state.tlsSession != null + && Arrays.AreEqual(state.selectedSessionID, state.tlsSession.SessionID); + + if (resumedSession) + { + if (securityParameters.CipherSuite != state.sessionParameters.CipherSuite + || securityParameters.CompressionAlgorithm != state.sessionParameters.CompressionAlgorithm) + { + throw new TlsFatalAlert(AlertDescription.illegal_parameter); + } + + IDictionary sessionServerExtensions = state.sessionParameters.ReadServerExtensions(); + + securityParameters.extendedMasterSecret = TlsExtensionsUtilities.HasExtendedMasterSecretExtension(sessionServerExtensions); + + securityParameters.masterSecret = Arrays.Clone(state.sessionParameters.MasterSecret); + recordLayer.InitPendingEpoch(state.client.GetCipher()); + + // NOTE: Calculated exclusive of the actual Finished message from the server + byte[] resExpectedServerVerifyData = TlsUtilities.CalculateVerifyData(state.clientContext, ExporterLabel.server_finished, + TlsProtocol.GetCurrentPrfHash(state.clientContext, handshake.HandshakeHash, null)); + ProcessFinished(handshake.ReceiveMessageBody(HandshakeType.finished), resExpectedServerVerifyData); + + // NOTE: Calculated exclusive of the Finished message itself + byte[] resClientVerifyData = TlsUtilities.CalculateVerifyData(state.clientContext, ExporterLabel.client_finished, + TlsProtocol.GetCurrentPrfHash(state.clientContext, handshake.HandshakeHash, null)); + handshake.SendMessage(HandshakeType.finished, resClientVerifyData); + + handshake.Finish(); + + state.clientContext.SetResumableSession(state.tlsSession); + + state.client.NotifyHandshakeComplete(); + + return new DtlsTransport(recordLayer); + } + + InvalidateSession(state); + + if (state.selectedSessionID.Length > 0) + { + state.tlsSession = new TlsSessionImpl(state.selectedSessionID, null); + } + + serverMessage = handshake.ReceiveMessage(); + + if (serverMessage.Type == HandshakeType.supplemental_data) + { + ProcessServerSupplementalData(state, serverMessage.Body); + serverMessage = handshake.ReceiveMessage(); + } + else + { + state.client.ProcessServerSupplementalData(null); + } + + state.keyExchange = state.client.GetKeyExchange(); + state.keyExchange.Init(state.clientContext); + + Certificate serverCertificate = null; + + if (serverMessage.Type == HandshakeType.certificate) + { + serverCertificate = ProcessServerCertificate(state, serverMessage.Body); + serverMessage = handshake.ReceiveMessage(); + } + else + { + // Okay, Certificate is optional + state.keyExchange.SkipServerCredentials(); + } + + // TODO[RFC 3546] Check whether empty certificates is possible, allowed, or excludes CertificateStatus + if (serverCertificate == null || serverCertificate.IsEmpty) + { + state.allowCertificateStatus = false; + } + + if (serverMessage.Type == HandshakeType.certificate_status) + { + ProcessCertificateStatus(state, serverMessage.Body); + serverMessage = handshake.ReceiveMessage(); + } + else + { + // Okay, CertificateStatus is optional + } + + if (serverMessage.Type == HandshakeType.server_key_exchange) + { + ProcessServerKeyExchange(state, serverMessage.Body); + serverMessage = handshake.ReceiveMessage(); + } + else + { + // Okay, ServerKeyExchange is optional + state.keyExchange.SkipServerKeyExchange(); + } + + if (serverMessage.Type == HandshakeType.certificate_request) + { + ProcessCertificateRequest(state, serverMessage.Body); + + /* + * TODO Give the client a chance to immediately select the CertificateVerify hash + * algorithm here to avoid tracking the other hash algorithms unnecessarily? + */ + TlsUtilities.TrackHashAlgorithms(handshake.HandshakeHash, + state.certificateRequest.SupportedSignatureAlgorithms); + + serverMessage = handshake.ReceiveMessage(); + } + else + { + // Okay, CertificateRequest is optional + } + + if (serverMessage.Type == HandshakeType.server_hello_done) + { + if (serverMessage.Body.Length != 0) + { + throw new TlsFatalAlert(AlertDescription.decode_error); + } + } + else + { + throw new TlsFatalAlert(AlertDescription.unexpected_message); + } + + handshake.HandshakeHash.SealHashAlgorithms(); + + IList clientSupplementalData = state.client.GetClientSupplementalData(); + if (clientSupplementalData != null) + { + byte[] supplementalDataBody = GenerateSupplementalData(clientSupplementalData); + handshake.SendMessage(HandshakeType.supplemental_data, supplementalDataBody); + } + + if (state.certificateRequest != null) + { + state.clientCredentials = state.authentication.GetClientCredentials(state.certificateRequest); + + /* + * RFC 5246 If no suitable certificate is available, the client MUST send a certificate + * message containing no certificates. + * + * NOTE: In previous RFCs, this was SHOULD instead of MUST. + */ + Certificate clientCertificate = null; + if (state.clientCredentials != null) + { + clientCertificate = state.clientCredentials.Certificate; + } + if (clientCertificate == null) + { + clientCertificate = Certificate.EmptyChain; + } + + byte[] certificateBody = GenerateCertificate(clientCertificate); + handshake.SendMessage(HandshakeType.certificate, certificateBody); + } + + if (state.clientCredentials != null) + { + state.keyExchange.ProcessClientCredentials(state.clientCredentials); + } + else + { + state.keyExchange.SkipClientCredentials(); + } + + byte[] clientKeyExchangeBody = GenerateClientKeyExchange(state); + handshake.SendMessage(HandshakeType.client_key_exchange, clientKeyExchangeBody); + + TlsHandshakeHash prepareFinishHash = handshake.PrepareToFinish(); + securityParameters.sessionHash = TlsProtocol.GetCurrentPrfHash(state.clientContext, prepareFinishHash, null); + + TlsProtocol.EstablishMasterSecret(state.clientContext, state.keyExchange); + recordLayer.InitPendingEpoch(state.client.GetCipher()); + + if (state.clientCredentials != null && state.clientCredentials is TlsSignerCredentials) + { + TlsSignerCredentials signerCredentials = (TlsSignerCredentials)state.clientCredentials; + + /* + * RFC 5246 4.7. digitally-signed element needs SignatureAndHashAlgorithm from TLS 1.2 + */ + SignatureAndHashAlgorithm signatureAndHashAlgorithm; + byte[] hash; + + if (TlsUtilities.IsTlsV12(state.clientContext)) + { + signatureAndHashAlgorithm = signerCredentials.SignatureAndHashAlgorithm; + if (signatureAndHashAlgorithm == null) + { + throw new TlsFatalAlert(AlertDescription.internal_error); + } + + hash = prepareFinishHash.GetFinalHash(signatureAndHashAlgorithm.Hash); + } + else + { + signatureAndHashAlgorithm = null; + hash = securityParameters.SessionHash; + } + + byte[] signature = signerCredentials.GenerateCertificateSignature(hash); + DigitallySigned certificateVerify = new DigitallySigned(signatureAndHashAlgorithm, signature); + byte[] certificateVerifyBody = GenerateCertificateVerify(state, certificateVerify); + handshake.SendMessage(HandshakeType.certificate_verify, certificateVerifyBody); + } + + // NOTE: Calculated exclusive of the Finished message itself + byte[] clientVerifyData = TlsUtilities.CalculateVerifyData(state.clientContext, ExporterLabel.client_finished, + TlsProtocol.GetCurrentPrfHash(state.clientContext, handshake.HandshakeHash, null)); + handshake.SendMessage(HandshakeType.finished, clientVerifyData); + + if (state.expectSessionTicket) + { + serverMessage = handshake.ReceiveMessage(); + if (serverMessage.Type == HandshakeType.session_ticket) + { + ProcessNewSessionTicket(state, serverMessage.Body); + } + else + { + throw new TlsFatalAlert(AlertDescription.unexpected_message); + } + } + + // NOTE: Calculated exclusive of the actual Finished message from the server + byte[] expectedServerVerifyData = TlsUtilities.CalculateVerifyData(state.clientContext, ExporterLabel.server_finished, + TlsProtocol.GetCurrentPrfHash(state.clientContext, handshake.HandshakeHash, null)); + ProcessFinished(handshake.ReceiveMessageBody(HandshakeType.finished), expectedServerVerifyData); + + handshake.Finish(); + + if (state.tlsSession != null) + { + state.sessionParameters = new SessionParameters.Builder() + .SetCipherSuite(securityParameters.cipherSuite) + .SetCompressionAlgorithm(securityParameters.compressionAlgorithm) + .SetMasterSecret(securityParameters.masterSecret) + .SetPeerCertificate(serverCertificate) + .Build(); + + state.tlsSession = TlsUtilities.ImportSession(state.tlsSession.SessionID, state.sessionParameters); + + state.clientContext.SetResumableSession(state.tlsSession); + } + + state.client.NotifyHandshakeComplete(); + + return new DtlsTransport(recordLayer); + } + + protected virtual byte[] GenerateCertificateVerify(ClientHandshakeState state, DigitallySigned certificateVerify) + { + MemoryStream buf = new MemoryStream(); + certificateVerify.Encode(buf); + return buf.ToArray(); + } + + protected virtual byte[] GenerateClientHello(ClientHandshakeState state, TlsClient client) + { + MemoryStream buf = new MemoryStream(); + + ProtocolVersion client_version = client.ClientVersion; + if (!client_version.IsDtls) + throw new TlsFatalAlert(AlertDescription.internal_error); + + TlsClientContextImpl context = state.clientContext; + + context.SetClientVersion(client_version); + TlsUtilities.WriteVersion(client_version, buf); + + SecurityParameters securityParameters = context.SecurityParameters; + buf.Write(securityParameters.ClientRandom, 0, securityParameters.ClientRandom.Length); + + // Session ID + byte[] session_id = TlsUtilities.EmptyBytes; + if (state.tlsSession != null) + { + session_id = state.tlsSession.SessionID; + if (session_id == null || session_id.Length > 32) + { + session_id = TlsUtilities.EmptyBytes; + } + } + TlsUtilities.WriteOpaque8(session_id, buf); + + // Cookie + TlsUtilities.WriteOpaque8(TlsUtilities.EmptyBytes, buf); + + bool fallback = client.IsFallback; + + /* + * Cipher suites + */ + state.offeredCipherSuites = client.GetCipherSuites(); + + // Integer -> byte[] + state.clientExtensions = client.GetClientExtensions(); + + securityParameters.extendedMasterSecret = TlsExtensionsUtilities.HasExtendedMasterSecretExtension(state.clientExtensions); + + // Cipher Suites (and SCSV) + { + /* + * RFC 5746 3.4. The client MUST include either an empty "renegotiation_info" extension, + * or the TLS_EMPTY_RENEGOTIATION_INFO_SCSV signaling cipher suite value in the + * ClientHello. Including both is NOT RECOMMENDED. + */ + byte[] renegExtData = TlsUtilities.GetExtensionData(state.clientExtensions, ExtensionType.renegotiation_info); + bool noRenegExt = (null == renegExtData); + + bool noRenegSCSV = !Arrays.Contains(state.offeredCipherSuites, CipherSuite.TLS_EMPTY_RENEGOTIATION_INFO_SCSV); + + if (noRenegExt && noRenegSCSV) + { + // TODO Consider whether to default to a client extension instead + state.offeredCipherSuites = Arrays.Append(state.offeredCipherSuites, CipherSuite.TLS_EMPTY_RENEGOTIATION_INFO_SCSV); + } + + /* + * draft-bmoeller-tls-downgrade-scsv-02 4. If a client sends a + * ClientHello.client_version containing a lower value than the latest (highest-valued) + * version supported by the client, it SHOULD include the TLS_FALLBACK_SCSV cipher suite + * value in ClientHello.cipher_suites. + */ + if (fallback && !Arrays.Contains(state.offeredCipherSuites, CipherSuite.TLS_FALLBACK_SCSV)) + { + state.offeredCipherSuites = Arrays.Append(state.offeredCipherSuites, CipherSuite.TLS_FALLBACK_SCSV); + } + + TlsUtilities.WriteUint16ArrayWithUint16Length(state.offeredCipherSuites, buf); + } + + // TODO Add support for compression + // Compression methods + // state.offeredCompressionMethods = client.getCompressionMethods(); + state.offeredCompressionMethods = new byte[]{ CompressionMethod.cls_null }; + + TlsUtilities.WriteUint8ArrayWithUint8Length(state.offeredCompressionMethods, buf); + + // Extensions + if (state.clientExtensions != null) + { + TlsProtocol.WriteExtensions(buf, state.clientExtensions); + } + + return buf.ToArray(); + } + + protected virtual byte[] GenerateClientKeyExchange(ClientHandshakeState state) + { + MemoryStream buf = new MemoryStream(); + state.keyExchange.GenerateClientKeyExchange(buf); + return buf.ToArray(); + } + + protected virtual void InvalidateSession(ClientHandshakeState state) + { + if (state.sessionParameters != null) + { + state.sessionParameters.Clear(); + state.sessionParameters = null; + } + + if (state.tlsSession != null) + { + state.tlsSession.Invalidate(); + state.tlsSession = null; + } + } + + protected virtual void ProcessCertificateRequest(ClientHandshakeState state, byte[] body) + { + if (state.authentication == null) + { + /* + * RFC 2246 7.4.4. It is a fatal handshake_failure alert for an anonymous server to + * request client identification. + */ + throw new TlsFatalAlert(AlertDescription.handshake_failure); + } + + MemoryStream buf = new MemoryStream(body, false); + + state.certificateRequest = CertificateRequest.Parse(state.clientContext, buf); + + TlsProtocol.AssertEmpty(buf); + + state.keyExchange.ValidateCertificateRequest(state.certificateRequest); + } + + protected virtual void ProcessCertificateStatus(ClientHandshakeState state, byte[] body) + { + if (!state.allowCertificateStatus) + { + /* + * RFC 3546 3.6. If a server returns a "CertificateStatus" message, then the + * server MUST have included an extension of type "status_request" with empty + * "extension_data" in the extended server hello.. + */ + throw new TlsFatalAlert(AlertDescription.unexpected_message); + } + + MemoryStream buf = new MemoryStream(body, false); + + state.certificateStatus = CertificateStatus.Parse(buf); + + TlsProtocol.AssertEmpty(buf); + + // TODO[RFC 3546] Figure out how to provide this to the client/authentication. + } + + protected virtual byte[] ProcessHelloVerifyRequest(ClientHandshakeState state, byte[] body) + { + MemoryStream buf = new MemoryStream(body, false); + + ProtocolVersion server_version = TlsUtilities.ReadVersion(buf); + byte[] cookie = TlsUtilities.ReadOpaque8(buf); + + TlsProtocol.AssertEmpty(buf); + + // TODO Seems this behaviour is not yet in line with OpenSSL for DTLS 1.2 + // reportServerVersion(state, server_version); + if (!server_version.IsEqualOrEarlierVersionOf(state.clientContext.ClientVersion)) + throw new TlsFatalAlert(AlertDescription.illegal_parameter); + + /* + * RFC 6347 This specification increases the cookie size limit to 255 bytes for greater + * future flexibility. The limit remains 32 for previous versions of DTLS. + */ + if (!ProtocolVersion.DTLSv12.IsEqualOrEarlierVersionOf(server_version) && cookie.Length > 32) + throw new TlsFatalAlert(AlertDescription.illegal_parameter); + + return cookie; + } + + protected virtual void ProcessNewSessionTicket(ClientHandshakeState state, byte[] body) + { + MemoryStream buf = new MemoryStream(body, false); + + NewSessionTicket newSessionTicket = NewSessionTicket.Parse(buf); + + TlsProtocol.AssertEmpty(buf); + + state.client.NotifyNewSessionTicket(newSessionTicket); + } + + protected virtual Certificate ProcessServerCertificate(ClientHandshakeState state, byte[] body) + { + MemoryStream buf = new MemoryStream(body, false); + + Certificate serverCertificate = Certificate.Parse(buf); + + TlsProtocol.AssertEmpty(buf); + + state.keyExchange.ProcessServerCertificate(serverCertificate); + state.authentication = state.client.GetAuthentication(); + state.authentication.NotifyServerCertificate(serverCertificate); + + return serverCertificate; + } + + protected virtual void ProcessServerHello(ClientHandshakeState state, byte[] body) + { + SecurityParameters securityParameters = state.clientContext.SecurityParameters; + + MemoryStream buf = new MemoryStream(body, false); + + ProtocolVersion server_version = TlsUtilities.ReadVersion(buf); + ReportServerVersion(state, server_version); + + securityParameters.serverRandom = TlsUtilities.ReadFully(32, buf); + + state.selectedSessionID = TlsUtilities.ReadOpaque8(buf); + if (state.selectedSessionID.Length > 32) + throw new TlsFatalAlert(AlertDescription.illegal_parameter); + state.client.NotifySessionID(state.selectedSessionID); + + state.selectedCipherSuite = TlsUtilities.ReadUint16(buf); + if (!Arrays.Contains(state.offeredCipherSuites, state.selectedCipherSuite) + || state.selectedCipherSuite == CipherSuite.TLS_NULL_WITH_NULL_NULL + || CipherSuite.IsScsv(state.selectedCipherSuite) + || !TlsUtilities.IsValidCipherSuiteForVersion(state.selectedCipherSuite, server_version)) + { + throw new TlsFatalAlert(AlertDescription.illegal_parameter); + } + + ValidateSelectedCipherSuite(state.selectedCipherSuite, AlertDescription.illegal_parameter); + + state.client.NotifySelectedCipherSuite(state.selectedCipherSuite); + + state.selectedCompressionMethod = TlsUtilities.ReadUint8(buf); + if (!Arrays.Contains(state.offeredCompressionMethods, (byte)state.selectedCompressionMethod)) + throw new TlsFatalAlert(AlertDescription.illegal_parameter); + state.client.NotifySelectedCompressionMethod((byte)state.selectedCompressionMethod); + + /* + * RFC3546 2.2 The extended server hello message format MAY be sent in place of the server + * hello message when the client has requested extended functionality via the extended + * client hello message specified in Section 2.1. ... Note that the extended server hello + * message is only sent in response to an extended client hello message. This prevents the + * possibility that the extended server hello message could "break" existing TLS 1.0 + * clients. + */ + + /* + * TODO RFC 3546 2.3 If [...] the older session is resumed, then the server MUST ignore + * extensions appearing in the client hello, and send a server hello containing no + * extensions. + */ + + // Integer -> byte[] + IDictionary serverExtensions = TlsProtocol.ReadExtensions(buf); + + /* + * draft-ietf-tls-session-hash-01 5.2. If a server receives the "extended_master_secret" + * extension, it MUST include the "extended_master_secret" extension in its ServerHello + * message. + */ + bool serverSentExtendedMasterSecret = TlsExtensionsUtilities.HasExtendedMasterSecretExtension(serverExtensions); + if (serverSentExtendedMasterSecret != securityParameters.extendedMasterSecret) + throw new TlsFatalAlert(AlertDescription.handshake_failure); + + /* + * RFC 3546 2.2 Note that the extended server hello message is only sent in response to an + * extended client hello message. However, see RFC 5746 exception below. We always include + * the SCSV, so an Extended Server Hello is always allowed. + */ + if (serverExtensions != null) + { + foreach (int extType in serverExtensions.Keys) + { + /* + * RFC 5746 3.6. Note that sending a "renegotiation_info" extension in response to a + * ClientHello containing only the SCSV is an explicit exception to the prohibition + * in RFC 5246, Section 7.4.1.4, on the server sending unsolicited extensions and is + * only allowed because the client is signaling its willingness to receive the + * extension via the TLS_EMPTY_RENEGOTIATION_INFO_SCSV SCSV. + */ + if (extType == ExtensionType.renegotiation_info) + continue; + + /* + * RFC 5246 7.4.1.4 An extension type MUST NOT appear in the ServerHello unless the + * same extension type appeared in the corresponding ClientHello. If a client + * receives an extension type in ServerHello that it did not request in the + * associated ClientHello, it MUST abort the handshake with an unsupported_extension + * fatal alert. + */ + if (null == TlsUtilities.GetExtensionData(state.clientExtensions, extType)) + throw new TlsFatalAlert(AlertDescription.unsupported_extension); + + /* + * draft-ietf-tls-session-hash-01 5.2. Implementation note: if the server decides to + * proceed with resumption, the extension does not have any effect. Requiring the + * extension to be included anyway makes the extension negotiation logic easier, + * because it does not depend on whether resumption is accepted or not. + */ + if (extType == ExtensionType.extended_master_secret) + continue; + + /* + * RFC 3546 2.3. If [...] the older session is resumed, then the server MUST ignore + * extensions appearing in the client hello, and send a server hello containing no + * extensions[.] + */ + // TODO[sessions] + // if (this.mResumedSession) + // { + // // TODO[compat-gnutls] GnuTLS test server sends server extensions e.g. ec_point_formats + // // TODO[compat-openssl] OpenSSL test server sends server extensions e.g. ec_point_formats + // // TODO[compat-polarssl] PolarSSL test server sends server extensions e.g. ec_point_formats + //// throw new TlsFatalAlert(AlertDescription.illegal_parameter); + // } + } + + /* + * RFC 5746 3.4. Client Behavior: Initial Handshake + */ + { + /* + * When a ServerHello is received, the client MUST check if it includes the + * "renegotiation_info" extension: + */ + byte[] renegExtData = (byte[])serverExtensions[ExtensionType.renegotiation_info]; + if (renegExtData != null) + { + /* + * If the extension is present, set the secure_renegotiation flag to TRUE. The + * client MUST then verify that the length of the "renegotiated_connection" + * field is zero, and if it is not, MUST abort the handshake (by sending a fatal + * handshake_failure alert). + */ + state.secure_renegotiation = true; + + if (!Arrays.ConstantTimeAreEqual(renegExtData, TlsProtocol.CreateRenegotiationInfo(TlsUtilities.EmptyBytes))) + throw new TlsFatalAlert(AlertDescription.handshake_failure); + } + } + + /* + * RFC 7366 3. If a server receives an encrypt-then-MAC request extension from a client + * and then selects a stream or Authenticated Encryption with Associated Data (AEAD) + * ciphersuite, it MUST NOT send an encrypt-then-MAC response extension back to the + * client. + */ + bool serverSentEncryptThenMAC = TlsExtensionsUtilities.HasEncryptThenMacExtension(serverExtensions); + if (serverSentEncryptThenMAC && !TlsUtilities.IsBlockCipherSuite(state.selectedCipherSuite)) + throw new TlsFatalAlert(AlertDescription.illegal_parameter); + + securityParameters.encryptThenMac = serverSentEncryptThenMAC; + + state.maxFragmentLength = EvaluateMaxFragmentLengthExtension(state.clientExtensions, serverExtensions, + AlertDescription.illegal_parameter); + + securityParameters.truncatedHMac = TlsExtensionsUtilities.HasTruncatedHMacExtension(serverExtensions); + + state.allowCertificateStatus = TlsUtilities.HasExpectedEmptyExtensionData(serverExtensions, + ExtensionType.status_request, AlertDescription.illegal_parameter); + + state.expectSessionTicket = TlsUtilities.HasExpectedEmptyExtensionData(serverExtensions, + ExtensionType.session_ticket, AlertDescription.illegal_parameter); + } + + state.client.NotifySecureRenegotiation(state.secure_renegotiation); + + if (state.clientExtensions != null) + { + state.client.ProcessServerExtensions(serverExtensions); + } + } + + protected virtual void ProcessServerKeyExchange(ClientHandshakeState state, byte[] body) + { + MemoryStream buf = new MemoryStream(body, false); + + state.keyExchange.ProcessServerKeyExchange(buf); + + TlsProtocol.AssertEmpty(buf); + } + + protected virtual void ProcessServerSupplementalData(ClientHandshakeState state, byte[] body) + { + MemoryStream buf = new MemoryStream(body, false); + IList serverSupplementalData = TlsProtocol.ReadSupplementalDataMessage(buf); + state.client.ProcessServerSupplementalData(serverSupplementalData); + } + + protected virtual void ReportServerVersion(ClientHandshakeState state, ProtocolVersion server_version) + { + TlsClientContextImpl clientContext = state.clientContext; + ProtocolVersion currentServerVersion = clientContext.ServerVersion; + if (null == currentServerVersion) + { + clientContext.SetServerVersion(server_version); + state.client.NotifyServerVersion(server_version); + } + else if (!currentServerVersion.Equals(server_version)) + { + throw new TlsFatalAlert(AlertDescription.illegal_parameter); + } + } + + protected static byte[] PatchClientHelloWithCookie(byte[] clientHelloBody, byte[] cookie) + { + int sessionIDPos = 34; + int sessionIDLength = TlsUtilities.ReadUint8(clientHelloBody, sessionIDPos); + + int cookieLengthPos = sessionIDPos + 1 + sessionIDLength; + int cookiePos = cookieLengthPos + 1; + + byte[] patched = new byte[clientHelloBody.Length + cookie.Length]; + Array.Copy(clientHelloBody, 0, patched, 0, cookieLengthPos); + TlsUtilities.CheckUint8(cookie.Length); + TlsUtilities.WriteUint8((byte)cookie.Length, patched, cookieLengthPos); + Array.Copy(cookie, 0, patched, cookiePos, cookie.Length); + Array.Copy(clientHelloBody, cookiePos, patched, cookiePos + cookie.Length, clientHelloBody.Length - cookiePos); + + return patched; + } + + protected internal class ClientHandshakeState + { + internal TlsClient client = null; + internal TlsClientContextImpl clientContext = null; + internal TlsSession tlsSession = null; + internal SessionParameters sessionParameters = null; + internal SessionParameters.Builder sessionParametersBuilder = null; + internal int[] offeredCipherSuites = null; + internal byte[] offeredCompressionMethods = null; + internal IDictionary clientExtensions = null; + internal byte[] selectedSessionID = null; + internal int selectedCipherSuite = -1; + internal short selectedCompressionMethod = -1; + internal bool secure_renegotiation = false; + internal short maxFragmentLength = -1; + internal bool allowCertificateStatus = false; + internal bool expectSessionTicket = false; + internal TlsKeyExchange keyExchange = null; + internal TlsAuthentication authentication = null; + internal CertificateStatus certificateStatus = null; + internal CertificateRequest certificateRequest = null; + internal TlsCredentials clientCredentials = null; + } + } +} diff --git a/crypto/src/crypto/tls/DtlsEpoch.cs b/crypto/src/crypto/tls/DtlsEpoch.cs new file mode 100644 index 000000000..91fffa5e1 --- /dev/null +++ b/crypto/src/crypto/tls/DtlsEpoch.cs @@ -0,0 +1,51 @@ +using System; + +namespace Org.BouncyCastle.Crypto.Tls +{ + internal class DtlsEpoch + { + private readonly DtlsReplayWindow mReplayWindow = new DtlsReplayWindow(); + + private readonly int mEpoch; + private readonly TlsCipher mCipher; + + private long mSequenceNumber = 0; + + internal DtlsEpoch(int epoch, TlsCipher cipher) + { + if (epoch < 0) + throw new ArgumentException("must be >= 0", "epoch"); + if (cipher == null) + throw new ArgumentNullException("cipher"); + + this.mEpoch = epoch; + this.mCipher = cipher; + } + + internal long AllocateSequenceNumber() + { + // TODO Check for overflow + return mSequenceNumber++; + } + + internal TlsCipher Cipher + { + get { return mCipher; } + } + + internal int Epoch + { + get { return mEpoch; } + } + + internal DtlsReplayWindow ReplayWindow + { + get { return mReplayWindow; } + } + + internal long SequenceNumber + { + get { return mSequenceNumber; } + } + } +} diff --git a/crypto/src/crypto/tls/DtlsHandshakeRetransmit.cs b/crypto/src/crypto/tls/DtlsHandshakeRetransmit.cs new file mode 100644 index 000000000..8bfae78b1 --- /dev/null +++ b/crypto/src/crypto/tls/DtlsHandshakeRetransmit.cs @@ -0,0 +1,11 @@ +using System; +using System.IO; + +namespace Org.BouncyCastle.Crypto.Tls +{ + interface DtlsHandshakeRetransmit + { + /// <exception cref="IOException"/> + void ReceivedHandshakeRecord(int epoch, byte[] buf, int off, int len); + } +} diff --git a/crypto/src/crypto/tls/DtlsProtocol.cs b/crypto/src/crypto/tls/DtlsProtocol.cs new file mode 100644 index 000000000..6d62c5a90 --- /dev/null +++ b/crypto/src/crypto/tls/DtlsProtocol.cs @@ -0,0 +1,72 @@ +using System; +using System.Collections; +using System.IO; + +using Org.BouncyCastle.Security; +using Org.BouncyCastle.Utilities; + +namespace Org.BouncyCastle.Crypto.Tls +{ + public abstract class DtlsProtocol + { + protected readonly SecureRandom mSecureRandom; + + protected DtlsProtocol(SecureRandom secureRandom) + { + if (secureRandom == null) + throw new ArgumentNullException("secureRandom"); + + this.mSecureRandom = secureRandom; + } + + /// <exception cref="IOException"/> + protected virtual void ProcessFinished(byte[] body, byte[] expected_verify_data) + { + MemoryStream buf = new MemoryStream(body, false); + + byte[] verify_data = TlsUtilities.ReadFully(expected_verify_data.Length, buf); + + TlsProtocol.AssertEmpty(buf); + + if (!Arrays.ConstantTimeAreEqual(expected_verify_data, verify_data)) + throw new TlsFatalAlert(AlertDescription.handshake_failure); + } + + /// <exception cref="IOException"/> + protected static short EvaluateMaxFragmentLengthExtension(IDictionary clientExtensions, IDictionary serverExtensions, + byte alertDescription) + { + short maxFragmentLength = TlsExtensionsUtilities.GetMaxFragmentLengthExtension(serverExtensions); + if (maxFragmentLength >= 0 && maxFragmentLength != TlsExtensionsUtilities.GetMaxFragmentLengthExtension(clientExtensions)) + throw new TlsFatalAlert(alertDescription); + return maxFragmentLength; + } + + /// <exception cref="IOException"/> + protected static byte[] GenerateCertificate(Certificate certificate) + { + MemoryStream buf = new MemoryStream(); + certificate.Encode(buf); + return buf.ToArray(); + } + + /// <exception cref="IOException"/> + protected static byte[] GenerateSupplementalData(IList supplementalData) + { + MemoryStream buf = new MemoryStream(); + TlsProtocol.WriteSupplementalData(buf, supplementalData); + return buf.ToArray(); + } + + /// <exception cref="IOException"/> + protected static void ValidateSelectedCipherSuite(int selectedCipherSuite, byte alertDescription) + { + switch (TlsUtilities.GetEncryptionAlgorithm(selectedCipherSuite)) + { + case EncryptionAlgorithm.RC4_40: + case EncryptionAlgorithm.RC4_128: + throw new TlsFatalAlert(alertDescription); + } + } + } +} diff --git a/crypto/src/crypto/tls/DtlsReassembler.cs b/crypto/src/crypto/tls/DtlsReassembler.cs new file mode 100644 index 000000000..11fe609cf --- /dev/null +++ b/crypto/src/crypto/tls/DtlsReassembler.cs @@ -0,0 +1,125 @@ +using System; +using System.Collections; + +using Org.BouncyCastle.Utilities; + +namespace Org.BouncyCastle.Crypto.Tls +{ + class DtlsReassembler + { + private readonly byte mMsgType; + private readonly byte[] mBody; + + private readonly IList mMissing = Platform.CreateArrayList(); + + internal DtlsReassembler(byte msg_type, int length) + { + this.mMsgType = msg_type; + this.mBody = new byte[length]; + this.mMissing.Add(new Range(0, length)); + } + + internal byte MsgType + { + get { return mMsgType; } + } + + internal byte[] GetBodyIfComplete() + { + return mMissing.Count == 0 ? mBody : null; + } + + internal void ContributeFragment(byte msg_type, int length, byte[] buf, int off, int fragment_offset, + int fragment_length) + { + int fragment_end = fragment_offset + fragment_length; + + if (this.mMsgType != msg_type || this.mBody.Length != length || fragment_end > length) + { + return; + } + + if (fragment_length == 0) + { + // NOTE: Empty messages still require an empty fragment to complete it + if (fragment_offset == 0 && mMissing.Count > 0) + { + Range firstRange = (Range)mMissing[0]; + if (firstRange.End == 0) + { + mMissing.RemoveAt(0); + } + } + return; + } + + for (int i = 0; i < mMissing.Count; ++i) + { + Range range = (Range)mMissing[i]; + if (range.Start >= fragment_end) + { + break; + } + if (range.End > fragment_offset) + { + + int copyStart = System.Math.Max(range.Start, fragment_offset); + int copyEnd = System.Math.Min(range.End, fragment_end); + int copyLength = copyEnd - copyStart; + + Array.Copy(buf, off + copyStart - fragment_offset, mBody, copyStart, + copyLength); + + if (copyStart == range.Start) + { + if (copyEnd == range.End) + { + mMissing.RemoveAt(i--); + } + else + { + range.Start = copyEnd; + } + } + else + { + if (copyEnd != range.End) + { + mMissing.Insert(++i, new Range(copyEnd, range.End)); + } + range.End = copyStart; + } + } + } + } + + internal void Reset() + { + this.mMissing.Clear(); + this.mMissing.Add(new Range(0, mBody.Length)); + } + + private class Range + { + private int mStart, mEnd; + + internal Range(int start, int end) + { + this.mStart = start; + this.mEnd = end; + } + + public int Start + { + get { return mStart; } + set { this.mStart = value; } + } + + public int End + { + get { return mEnd; } + set { this.mEnd = value; } + } + } + } +} diff --git a/crypto/src/crypto/tls/DtlsRecordLayer.cs b/crypto/src/crypto/tls/DtlsRecordLayer.cs new file mode 100644 index 000000000..70befd9e4 --- /dev/null +++ b/crypto/src/crypto/tls/DtlsRecordLayer.cs @@ -0,0 +1,507 @@ +using System; +using System.IO; + +using Org.BouncyCastle.Utilities.Date; + +namespace Org.BouncyCastle.Crypto.Tls +{ + internal class DtlsRecordLayer + : DatagramTransport + { + private const int RECORD_HEADER_LENGTH = 13; + private const int MAX_FRAGMENT_LENGTH = 1 << 14; + private const long TCP_MSL = 1000L * 60 * 2; + private const long RETRANSMIT_TIMEOUT = TCP_MSL * 2; + + private readonly DatagramTransport mTransport; + private readonly TlsContext mContext; + private readonly TlsPeer mPeer; + + private readonly ByteQueue mRecordQueue = new ByteQueue(); + + private volatile bool mClosed = false; + private volatile bool mFailed = false; + private volatile ProtocolVersion mDiscoveredPeerVersion = null; + private volatile bool mInHandshake; + private volatile int mPlaintextLimit; + private DtlsEpoch mCurrentEpoch, mPendingEpoch; + private DtlsEpoch mReadEpoch, mWriteEpoch; + + private DtlsHandshakeRetransmit mRetransmit = null; + private DtlsEpoch mRetransmitEpoch = null; + private long mRetransmitExpiry = 0; + + internal DtlsRecordLayer(DatagramTransport transport, TlsContext context, TlsPeer peer, byte contentType) + { + this.mTransport = transport; + this.mContext = context; + this.mPeer = peer; + + this.mInHandshake = true; + + this.mCurrentEpoch = new DtlsEpoch(0, new TlsNullCipher(context)); + this.mPendingEpoch = null; + this.mReadEpoch = mCurrentEpoch; + this.mWriteEpoch = mCurrentEpoch; + + SetPlaintextLimit(MAX_FRAGMENT_LENGTH); + } + + internal virtual void SetPlaintextLimit(int plaintextLimit) + { + this.mPlaintextLimit = plaintextLimit; + } + + internal virtual ProtocolVersion DiscoveredPeerVersion + { + get { return mDiscoveredPeerVersion; } + } + + internal virtual ProtocolVersion ResetDiscoveredPeerVersion() + { + ProtocolVersion result = mDiscoveredPeerVersion; + mDiscoveredPeerVersion = null; + return result; + } + + internal virtual void InitPendingEpoch(TlsCipher pendingCipher) + { + if (mPendingEpoch != null) + throw new InvalidOperationException(); + + /* + * TODO "In order to ensure that any given sequence/epoch pair is unique, implementations + * MUST NOT allow the same epoch value to be reused within two times the TCP maximum segment + * lifetime." + */ + + // TODO Check for overflow + this.mPendingEpoch = new DtlsEpoch(mWriteEpoch.Epoch + 1, pendingCipher); + } + + internal virtual void HandshakeSuccessful(DtlsHandshakeRetransmit retransmit) + { + if (mReadEpoch == mCurrentEpoch || mWriteEpoch == mCurrentEpoch) + { + // TODO + throw new InvalidOperationException(); + } + + if (retransmit != null) + { + this.mRetransmit = retransmit; + this.mRetransmitEpoch = mCurrentEpoch; + this.mRetransmitExpiry = DateTimeUtilities.CurrentUnixMs() + RETRANSMIT_TIMEOUT; + } + + this.mInHandshake = false; + this.mCurrentEpoch = mPendingEpoch; + this.mPendingEpoch = null; + } + + internal virtual void ResetWriteEpoch() + { + if (mRetransmitEpoch != null) + { + this.mWriteEpoch = mRetransmitEpoch; + } + else + { + this.mWriteEpoch = mCurrentEpoch; + } + } + + public virtual int GetReceiveLimit() + { + return System.Math.Min(this.mPlaintextLimit, + mReadEpoch.Cipher.GetPlaintextLimit(mTransport.GetReceiveLimit() - RECORD_HEADER_LENGTH)); + } + + public virtual int GetSendLimit() + { + return System.Math.Min(this.mPlaintextLimit, + mWriteEpoch.Cipher.GetPlaintextLimit(mTransport.GetSendLimit() - RECORD_HEADER_LENGTH)); + } + + public virtual int Receive(byte[] buf, int off, int len, int waitMillis) + { + byte[] record = null; + + for (;;) + { + int receiveLimit = System.Math.Min(len, GetReceiveLimit()) + RECORD_HEADER_LENGTH; + if (record == null || record.Length < receiveLimit) + { + record = new byte[receiveLimit]; + } + + try + { + if (mRetransmit != null && DateTimeUtilities.CurrentUnixMs() > mRetransmitExpiry) + { + mRetransmit = null; + mRetransmitEpoch = null; + } + + int received = ReceiveRecord(record, 0, receiveLimit, waitMillis); + if (received < 0) + { + return received; + } + if (received < RECORD_HEADER_LENGTH) + { + continue; + } + int length = TlsUtilities.ReadUint16(record, 11); + if (received != (length + RECORD_HEADER_LENGTH)) + { + continue; + } + + byte type = TlsUtilities.ReadUint8(record, 0); + + // TODO Support user-specified custom protocols? + switch (type) + { + case ContentType.alert: + case ContentType.application_data: + case ContentType.change_cipher_spec: + case ContentType.handshake: + case ContentType.heartbeat: + break; + default: + // TODO Exception? + continue; + } + + int epoch = TlsUtilities.ReadUint16(record, 3); + + DtlsEpoch recordEpoch = null; + if (epoch == mReadEpoch.Epoch) + { + recordEpoch = mReadEpoch; + } + else if (type == ContentType.handshake && mRetransmitEpoch != null + && epoch == mRetransmitEpoch.Epoch) + { + recordEpoch = mRetransmitEpoch; + } + + if (recordEpoch == null) + { + continue; + } + + long seq = TlsUtilities.ReadUint48(record, 5); + if (recordEpoch.ReplayWindow.ShouldDiscard(seq)) + { + continue; + } + + ProtocolVersion version = TlsUtilities.ReadVersion(record, 1); + if (mDiscoveredPeerVersion != null && !mDiscoveredPeerVersion.Equals(version)) + { + continue; + } + + byte[] plaintext = recordEpoch.Cipher.DecodeCiphertext( + GetMacSequenceNumber(recordEpoch.Epoch, seq), type, record, RECORD_HEADER_LENGTH, + received - RECORD_HEADER_LENGTH); + + recordEpoch.ReplayWindow.ReportAuthenticated(seq); + + if (plaintext.Length > this.mPlaintextLimit) + { + continue; + } + + if (mDiscoveredPeerVersion == null) + { + mDiscoveredPeerVersion = version; + } + + switch (type) + { + case ContentType.alert: + { + if (plaintext.Length == 2) + { + byte alertLevel = plaintext[0]; + byte alertDescription = plaintext[1]; + + mPeer.NotifyAlertReceived(alertLevel, alertDescription); + + if (alertLevel == AlertLevel.fatal) + { + Fail(alertDescription); + throw new TlsFatalAlert(alertDescription); + } + + // TODO Can close_notify be a fatal alert? + if (alertDescription == AlertDescription.close_notify) + { + CloseTransport(); + } + } + + continue; + } + case ContentType.application_data: + { + if (mInHandshake) + { + // TODO Consider buffering application data for new epoch that arrives + // out-of-order with the Finished message + continue; + } + break; + } + case ContentType.change_cipher_spec: + { + // Implicitly receive change_cipher_spec and change to pending cipher state + + for (int i = 0; i < plaintext.Length; ++i) + { + byte message = TlsUtilities.ReadUint8(plaintext, i); + if (message != ChangeCipherSpec.change_cipher_spec) + { + continue; + } + + if (mPendingEpoch != null) + { + mReadEpoch = mPendingEpoch; + } + } + + continue; + } + case ContentType.handshake: + { + if (!mInHandshake) + { + if (mRetransmit != null) + { + mRetransmit.ReceivedHandshakeRecord(epoch, plaintext, 0, plaintext.Length); + } + + // TODO Consider support for HelloRequest + continue; + } + break; + } + case ContentType.heartbeat: + { + // TODO[RFC 6520] + continue; + } + } + + /* + * NOTE: If we receive any non-handshake data in the new epoch implies the peer has + * received our final flight. + */ + if (!mInHandshake && mRetransmit != null) + { + this.mRetransmit = null; + this.mRetransmitEpoch = null; + } + + Array.Copy(plaintext, 0, buf, off, plaintext.Length); + return plaintext.Length; + } + catch (IOException e) + { + // NOTE: Assume this is a timeout for the moment + throw e; + } + } + } + + /// <exception cref="IOException"/> + public virtual void Send(byte[] buf, int off, int len) + { + byte contentType = ContentType.application_data; + + if (this.mInHandshake || this.mWriteEpoch == this.mRetransmitEpoch) + { + contentType = ContentType.handshake; + + byte handshakeType = TlsUtilities.ReadUint8(buf, off); + if (handshakeType == HandshakeType.finished) + { + DtlsEpoch nextEpoch = null; + if (this.mInHandshake) + { + nextEpoch = mPendingEpoch; + } + else if (this.mWriteEpoch == this.mRetransmitEpoch) + { + nextEpoch = mCurrentEpoch; + } + + if (nextEpoch == null) + { + // TODO + throw new InvalidOperationException(); + } + + // Implicitly send change_cipher_spec and change to pending cipher state + + // TODO Send change_cipher_spec and finished records in single datagram? + byte[] data = new byte[]{ 1 }; + SendRecord(ContentType.change_cipher_spec, data, 0, data.Length); + + mWriteEpoch = nextEpoch; + } + } + + SendRecord(contentType, buf, off, len); + } + + public virtual void Close() + { + if (!mClosed) + { + if (mInHandshake) + { + Warn(AlertDescription.user_canceled, "User canceled handshake"); + } + CloseTransport(); + } + } + + internal virtual void Fail(byte alertDescription) + { + if (!mClosed) + { + try + { + RaiseAlert(AlertLevel.fatal, alertDescription, null, null); + } + catch (Exception) + { + // Ignore + } + + mFailed = true; + + CloseTransport(); + } + } + + internal virtual void Warn(byte alertDescription, string message) + { + RaiseAlert(AlertLevel.warning, alertDescription, message, null); + } + + private void CloseTransport() + { + if (!mClosed) + { + /* + * RFC 5246 7.2.1. Unless some other fatal alert has been transmitted, each party is + * required to send a close_notify alert before closing the write side of the + * connection. The other party MUST respond with a close_notify alert of its own and + * close down the connection immediately, discarding any pending writes. + */ + + try + { + if (!mFailed) + { + Warn(AlertDescription.close_notify, null); + } + mTransport.Close(); + } + catch (Exception) + { + // Ignore + } + + mClosed = true; + } + } + + private void RaiseAlert(byte alertLevel, byte alertDescription, string message, Exception cause) + { + mPeer.NotifyAlertRaised(alertLevel, alertDescription, message, cause); + + byte[] error = new byte[2]; + error[0] = (byte)alertLevel; + error[1] = (byte)alertDescription; + + SendRecord(ContentType.alert, error, 0, 2); + } + + private int ReceiveRecord(byte[] buf, int off, int len, int waitMillis) + { + if (mRecordQueue.Available > 0) + { + int length = 0; + if (mRecordQueue.Available >= RECORD_HEADER_LENGTH) + { + byte[] lengthBytes = new byte[2]; + mRecordQueue.Read(lengthBytes, 0, 2, 11); + length = TlsUtilities.ReadUint16(lengthBytes, 0); + } + + int received = System.Math.Min(mRecordQueue.Available, RECORD_HEADER_LENGTH + length); + mRecordQueue.RemoveData(buf, off, received, 0); + return received; + } + + { + int received = mTransport.Receive(buf, off, len, waitMillis); + if (received >= RECORD_HEADER_LENGTH) + { + int fragmentLength = TlsUtilities.ReadUint16(buf, off + 11); + int recordLength = RECORD_HEADER_LENGTH + fragmentLength; + if (received > recordLength) + { + mRecordQueue.AddData(buf, off + recordLength, received - recordLength); + received = recordLength; + } + } + return received; + } + } + + private void SendRecord(byte contentType, byte[] buf, int off, int len) + { + if (len > this.mPlaintextLimit) + throw new TlsFatalAlert(AlertDescription.internal_error); + + /* + * RFC 5264 6.2.1 Implementations MUST NOT send zero-length fragments of Handshake, Alert, + * or ChangeCipherSpec content types. + */ + if (len < 1 && contentType != ContentType.application_data) + throw new TlsFatalAlert(AlertDescription.internal_error); + + int recordEpoch = mWriteEpoch.Epoch; + long recordSequenceNumber = mWriteEpoch.AllocateSequenceNumber(); + + byte[] ciphertext = mWriteEpoch.Cipher.EncodePlaintext( + GetMacSequenceNumber(recordEpoch, recordSequenceNumber), contentType, buf, off, len); + + // TODO Check the ciphertext length? + + byte[] record = new byte[ciphertext.Length + RECORD_HEADER_LENGTH]; + TlsUtilities.WriteUint8(contentType, record, 0); + ProtocolVersion version = mDiscoveredPeerVersion != null ? mDiscoveredPeerVersion : mContext.ClientVersion; + TlsUtilities.WriteVersion(version, record, 1); + TlsUtilities.WriteUint16(recordEpoch, record, 3); + TlsUtilities.WriteUint48(recordSequenceNumber, record, 5); + TlsUtilities.WriteUint16(ciphertext.Length, record, 11); + Array.Copy(ciphertext, 0, record, RECORD_HEADER_LENGTH, ciphertext.Length); + + mTransport.Send(record, 0, record.Length); + } + + private static long GetMacSequenceNumber(int epoch, long sequence_number) + { + return ((epoch & 0xFFFFFFFFL) << 48) | sequence_number; + } + } +} diff --git a/crypto/src/crypto/tls/DtlsReliableHandshake.cs b/crypto/src/crypto/tls/DtlsReliableHandshake.cs new file mode 100644 index 000000000..bf9e61d03 --- /dev/null +++ b/crypto/src/crypto/tls/DtlsReliableHandshake.cs @@ -0,0 +1,443 @@ +using System; +using System.Collections; +using System.IO; + +using Org.BouncyCastle.Utilities; + +namespace Org.BouncyCastle.Crypto.Tls +{ + internal class DtlsReliableHandshake + { + private const int MAX_RECEIVE_AHEAD = 10; + + private readonly DtlsRecordLayer mRecordLayer; + + private TlsHandshakeHash mHandshakeHash; + + private IDictionary mCurrentInboundFlight = Platform.CreateHashtable(); + private IDictionary mPreviousInboundFlight = null; + private IList mOutboundFlight = Platform.CreateArrayList(); + private bool mSending = true; + + private int mMessageSeq = 0, mNextReceiveSeq = 0; + + internal DtlsReliableHandshake(TlsContext context, DtlsRecordLayer transport) + { + this.mRecordLayer = transport; + this.mHandshakeHash = new DeferredHash(); + this.mHandshakeHash.Init(context); + } + + internal void NotifyHelloComplete() + { + this.mHandshakeHash = mHandshakeHash.NotifyPrfDetermined(); + } + + internal TlsHandshakeHash HandshakeHash + { + get { return mHandshakeHash; } + } + + internal TlsHandshakeHash PrepareToFinish() + { + TlsHandshakeHash result = mHandshakeHash; + this.mHandshakeHash = mHandshakeHash.StopTracking(); + return result; + } + + internal void SendMessage(byte msg_type, byte[] body) + { + TlsUtilities.CheckUint24(body.Length); + + if (!mSending) + { + CheckInboundFlight(); + mSending = true; + mOutboundFlight.Clear(); + } + + Message message = new Message(mMessageSeq++, msg_type, body); + + mOutboundFlight.Add(message); + + WriteMessage(message); + UpdateHandshakeMessagesDigest(message); + } + + internal byte[] ReceiveMessageBody(byte msg_type) + { + Message message = ReceiveMessage(); + if (message.Type != msg_type) + throw new TlsFatalAlert(AlertDescription.unexpected_message); + + return message.Body; + } + + internal Message ReceiveMessage() + { + if (mSending) + { + mSending = false; + PrepareInboundFlight(); + } + + // Check if we already have the next message waiting + { + DtlsReassembler next = (DtlsReassembler)mCurrentInboundFlight[mNextReceiveSeq]; + if (next != null) + { + byte[] body = next.GetBodyIfComplete(); + if (body != null) + { + mPreviousInboundFlight = null; + return UpdateHandshakeMessagesDigest(new Message(mNextReceiveSeq++, next.MsgType, body)); + } + } + } + + byte[] buf = null; + + // TODO Check the conditions under which we should reset this + int readTimeoutMillis = 1000; + + for (;;) + { + int receiveLimit = mRecordLayer.GetReceiveLimit(); + if (buf == null || buf.Length < receiveLimit) + { + buf = new byte[receiveLimit]; + } + + // TODO Handle records containing multiple handshake messages + + try + { + for (; ; ) + { + int Received = mRecordLayer.Receive(buf, 0, receiveLimit, readTimeoutMillis); + if (Received < 0) + { + break; + } + if (Received < 12) + { + continue; + } + int fragment_length = TlsUtilities.ReadUint24(buf, 9); + if (Received != (fragment_length + 12)) + { + continue; + } + int seq = TlsUtilities.ReadUint16(buf, 4); + if (seq > (mNextReceiveSeq + MAX_RECEIVE_AHEAD)) + { + continue; + } + byte msg_type = TlsUtilities.ReadUint8(buf, 0); + int length = TlsUtilities.ReadUint24(buf, 1); + int fragment_offset = TlsUtilities.ReadUint24(buf, 6); + if (fragment_offset + fragment_length > length) + { + continue; + } + + if (seq < mNextReceiveSeq) + { + /* + * NOTE: If we Receive the previous flight of incoming messages in full + * again, retransmit our last flight + */ + if (mPreviousInboundFlight != null) + { + DtlsReassembler reassembler = (DtlsReassembler)mPreviousInboundFlight[seq]; + if (reassembler != null) + { + reassembler.ContributeFragment(msg_type, length, buf, 12, fragment_offset, + fragment_length); + + if (CheckAll(mPreviousInboundFlight)) + { + ResendOutboundFlight(); + + /* + * TODO[DTLS] implementations SHOULD back off handshake packet + * size during the retransmit backoff. + */ + readTimeoutMillis = System.Math.Min(readTimeoutMillis * 2, 60000); + + ResetAll(mPreviousInboundFlight); + } + } + } + } + else + { + DtlsReassembler reassembler = (DtlsReassembler)mCurrentInboundFlight[seq]; + if (reassembler == null) + { + reassembler = new DtlsReassembler(msg_type, length); + mCurrentInboundFlight[seq] = reassembler; + } + + reassembler.ContributeFragment(msg_type, length, buf, 12, fragment_offset, fragment_length); + + if (seq == mNextReceiveSeq) + { + byte[] body = reassembler.GetBodyIfComplete(); + if (body != null) + { + mPreviousInboundFlight = null; + return UpdateHandshakeMessagesDigest(new Message(mNextReceiveSeq++, + reassembler.MsgType, body)); + } + } + } + } + } + catch (IOException) + { + // NOTE: Assume this is a timeout for the moment + } + + ResendOutboundFlight(); + + /* + * TODO[DTLS] implementations SHOULD back off handshake packet size during the + * retransmit backoff. + */ + readTimeoutMillis = System.Math.Min(readTimeoutMillis * 2, 60000); + } + } + + internal void Finish() + { + DtlsHandshakeRetransmit retransmit = null; + if (!mSending) + { + CheckInboundFlight(); + } + else if (mCurrentInboundFlight != null) + { + /* + * RFC 6347 4.2.4. In addition, for at least twice the default MSL defined for [TCP], + * when in the FINISHED state, the node that transmits the last flight (the server in an + * ordinary handshake or the client in a resumed handshake) MUST respond to a retransmit + * of the peer's last flight with a retransmit of the last flight. + */ + retransmit = new Retransmit(this); + } + + mRecordLayer.HandshakeSuccessful(retransmit); + } + + internal void ResetHandshakeMessagesDigest() + { + mHandshakeHash.Reset(); + } + + private void HandleRetransmittedHandshakeRecord(int epoch, byte[] buf, int off, int len) + { + /* + * TODO Need to handle the case where the previous inbound flight contains + * messages from two epochs. + */ + if (len < 12) + return; + int fragment_length = TlsUtilities.ReadUint24(buf, off + 9); + if (len != (fragment_length + 12)) + return; + int seq = TlsUtilities.ReadUint16(buf, off + 4); + if (seq >= mNextReceiveSeq) + return; + + byte msg_type = TlsUtilities.ReadUint8(buf, off); + + // TODO This is a hack that only works until we try to support renegotiation + int expectedEpoch = msg_type == HandshakeType.finished ? 1 : 0; + if (epoch != expectedEpoch) + return; + + int length = TlsUtilities.ReadUint24(buf, off + 1); + int fragment_offset = TlsUtilities.ReadUint24(buf, off + 6); + if (fragment_offset + fragment_length > length) + return; + + DtlsReassembler reassembler = (DtlsReassembler)mCurrentInboundFlight[seq]; + if (reassembler != null) + { + reassembler.ContributeFragment(msg_type, length, buf, off + 12, fragment_offset, + fragment_length); + if (CheckAll(mCurrentInboundFlight)) + { + ResendOutboundFlight(); + ResetAll(mCurrentInboundFlight); + } + } + } + + /** + * Check that there are no "extra" messages left in the current inbound flight + */ + private void CheckInboundFlight() + { + foreach (int key in mCurrentInboundFlight.Keys) + { + if (key >= mNextReceiveSeq) + { + // TODO Should this be considered an error? + } + } + } + + private void PrepareInboundFlight() + { + ResetAll(mCurrentInboundFlight); + mPreviousInboundFlight = mCurrentInboundFlight; + mCurrentInboundFlight = Platform.CreateHashtable(); + } + + private void ResendOutboundFlight() + { + mRecordLayer.ResetWriteEpoch(); + for (int i = 0; i < mOutboundFlight.Count; ++i) + { + WriteMessage((Message)mOutboundFlight[i]); + } + } + + private Message UpdateHandshakeMessagesDigest(Message message) + { + if (message.Type != HandshakeType.hello_request) + { + byte[] body = message.Body; + byte[] buf = new byte[12]; + TlsUtilities.WriteUint8(message.Type, buf, 0); + TlsUtilities.WriteUint24(body.Length, buf, 1); + TlsUtilities.WriteUint16(message.Seq, buf, 4); + TlsUtilities.WriteUint24(0, buf, 6); + TlsUtilities.WriteUint24(body.Length, buf, 9); + mHandshakeHash.BlockUpdate(buf, 0, buf.Length); + mHandshakeHash.BlockUpdate(body, 0, body.Length); + } + return message; + } + + private void WriteMessage(Message message) + { + int sendLimit = mRecordLayer.GetSendLimit(); + int fragmentLimit = sendLimit - 12; + + // TODO Support a higher minimum fragment size? + if (fragmentLimit < 1) + { + // TODO Should we be throwing an exception here? + throw new TlsFatalAlert(AlertDescription.internal_error); + } + + int length = message.Body.Length; + + // NOTE: Must still send a fragment if body is empty + int fragment_offset = 0; + do + { + int fragment_length = System.Math.Min(length - fragment_offset, fragmentLimit); + WriteHandshakeFragment(message, fragment_offset, fragment_length); + fragment_offset += fragment_length; + } + while (fragment_offset < length); + } + + private void WriteHandshakeFragment(Message message, int fragment_offset, int fragment_length) + { + RecordLayerBuffer fragment = new RecordLayerBuffer(12 + fragment_length); + TlsUtilities.WriteUint8(message.Type, fragment); + TlsUtilities.WriteUint24(message.Body.Length, fragment); + TlsUtilities.WriteUint16(message.Seq, fragment); + TlsUtilities.WriteUint24(fragment_offset, fragment); + TlsUtilities.WriteUint24(fragment_length, fragment); + fragment.Write(message.Body, fragment_offset, fragment_length); + + fragment.SendToRecordLayer(mRecordLayer); + } + + private static bool CheckAll(IDictionary inboundFlight) + { + foreach (DtlsReassembler r in inboundFlight.Values) + { + if (r.GetBodyIfComplete() == null) + { + return false; + } + } + return true; + } + + private static void ResetAll(IDictionary inboundFlight) + { + foreach (DtlsReassembler r in inboundFlight.Values) + { + r.Reset(); + } + } + + internal class Message + { + private readonly int mMessageSeq; + private readonly byte mMsgType; + private readonly byte[] mBody; + + internal Message(int message_seq, byte msg_type, byte[] body) + { + this.mMessageSeq = message_seq; + this.mMsgType = msg_type; + this.mBody = body; + } + + public int Seq + { + get { return mMessageSeq; } + } + + public byte Type + { + get { return mMsgType; } + } + + public byte[] Body + { + get { return mBody; } + } + } + + internal class RecordLayerBuffer + : MemoryStream + { + internal RecordLayerBuffer(int size) + : base(size) + { + } + + internal void SendToRecordLayer(DtlsRecordLayer recordLayer) + { + recordLayer.Send(GetBuffer(), 0, (int)Length); + this.Close(); + } + } + + internal class Retransmit + : DtlsHandshakeRetransmit + { + private readonly DtlsReliableHandshake mOuter; + + internal Retransmit(DtlsReliableHandshake outer) + { + this.mOuter = outer; + } + + public void ReceivedHandshakeRecord(int epoch, byte[] buf, int off, int len) + { + mOuter.HandleRetransmittedHandshakeRecord(epoch, buf, off, len); + } + } + } +} diff --git a/crypto/src/crypto/tls/DtlsReplayWindow.cs b/crypto/src/crypto/tls/DtlsReplayWindow.cs new file mode 100644 index 000000000..ea18e805e --- /dev/null +++ b/crypto/src/crypto/tls/DtlsReplayWindow.cs @@ -0,0 +1,85 @@ +using System; + +namespace Org.BouncyCastle.Crypto.Tls +{ + /** + * RFC 4347 4.1.2.5 Anti-replay + * <p/> + * Support fast rejection of duplicate records by maintaining a sliding receive window + */ + internal class DtlsReplayWindow + { + private const long VALID_SEQ_MASK = 0x0000FFFFFFFFFFFFL; + + private const long WINDOW_SIZE = 64L; + + private long mLatestConfirmedSeq = -1; + private long mBitmap = 0; + + /** + * Check whether a received record with the given sequence number should be rejected as a duplicate. + * + * @param seq the 48-bit DTLSPlainText.sequence_number field of a received record. + * @return true if the record should be discarded without further processing. + */ + internal bool ShouldDiscard(long seq) + { + if ((seq & VALID_SEQ_MASK) != seq) + return true; + + if (seq <= mLatestConfirmedSeq) + { + long diff = mLatestConfirmedSeq - seq; + if (diff >= WINDOW_SIZE) + return true; + if ((mBitmap & (1L << (int)diff)) != 0) + return true; + } + + return false; + } + + /** + * Report that a received record with the given sequence number passed authentication checks. + * + * @param seq the 48-bit DTLSPlainText.sequence_number field of an authenticated record. + */ + internal void ReportAuthenticated(long seq) + { + if ((seq & VALID_SEQ_MASK) != seq) + throw new ArgumentException("out of range", "seq"); + + if (seq <= mLatestConfirmedSeq) + { + long diff = mLatestConfirmedSeq - seq; + if (diff < WINDOW_SIZE) + { + mBitmap |= (1L << (int)diff); + } + } + else + { + long diff = seq - mLatestConfirmedSeq; + if (diff >= WINDOW_SIZE) + { + mBitmap = 1; + } + else + { + mBitmap <<= (int)diff; + mBitmap |= 1; + } + mLatestConfirmedSeq = seq; + } + } + + /** + * When a new epoch begins, sequence numbers begin again at 0 + */ + internal void Reset() + { + mLatestConfirmedSeq = -1; + mBitmap = 0; + } + } +} diff --git a/crypto/src/crypto/tls/DtlsServerProtocol.cs b/crypto/src/crypto/tls/DtlsServerProtocol.cs new file mode 100644 index 000000000..3335a9f36 --- /dev/null +++ b/crypto/src/crypto/tls/DtlsServerProtocol.cs @@ -0,0 +1,642 @@ +using System; +using System.Collections; +using System.IO; + +using Org.BouncyCastle.Asn1.X509; +using Org.BouncyCastle.Crypto.Parameters; +using Org.BouncyCastle.Security; +using Org.BouncyCastle.Utilities; + +namespace Org.BouncyCastle.Crypto.Tls +{ + public class DtlsServerProtocol + : DtlsProtocol + { + protected bool mVerifyRequests = true; + + public DtlsServerProtocol(SecureRandom secureRandom) + : base(secureRandom) + { + } + + public virtual bool VerifyRequests + { + get { return mVerifyRequests; } + set { this.mVerifyRequests = value; } + } + + public virtual DtlsTransport Accept(TlsServer server, DatagramTransport transport) + { + if (server == null) + throw new ArgumentNullException("server"); + if (transport == null) + throw new ArgumentNullException("transport"); + + SecurityParameters securityParameters = new SecurityParameters(); + securityParameters.entity = ConnectionEnd.server; + + ServerHandshakeState state = new ServerHandshakeState(); + state.server = server; + state.serverContext = new TlsServerContextImpl(mSecureRandom, securityParameters); + + securityParameters.serverRandom = TlsProtocol.CreateRandomBlock(server.ShouldUseGmtUnixTime(), + state.serverContext.NonceRandomGenerator); + + server.Init(state.serverContext); + + DtlsRecordLayer recordLayer = new DtlsRecordLayer(transport, state.serverContext, server, ContentType.handshake); + + // TODO Need to handle sending of HelloVerifyRequest without entering a full connection + + try + { + return ServerHandshake(state, recordLayer); + } + catch (TlsFatalAlert fatalAlert) + { + recordLayer.Fail(fatalAlert.AlertDescription); + throw fatalAlert; + } + catch (IOException e) + { + recordLayer.Fail(AlertDescription.internal_error); + throw e; + } + catch (Exception e) + { + recordLayer.Fail(AlertDescription.internal_error); + throw new TlsFatalAlert(AlertDescription.internal_error, e); + } + } + + internal virtual DtlsTransport ServerHandshake(ServerHandshakeState state, DtlsRecordLayer recordLayer) + { + SecurityParameters securityParameters = state.serverContext.SecurityParameters; + DtlsReliableHandshake handshake = new DtlsReliableHandshake(state.serverContext, recordLayer); + + DtlsReliableHandshake.Message clientMessage = handshake.ReceiveMessage(); + + { + // NOTE: After receiving a record from the client, we discover the record layer version + ProtocolVersion client_version = recordLayer.DiscoveredPeerVersion; + // TODO Read RFCs for guidance on the expected record layer version number + state.serverContext.SetClientVersion(client_version); + } + + if (clientMessage.Type == HandshakeType.client_hello) + { + ProcessClientHello(state, clientMessage.Body); + } + else + { + throw new TlsFatalAlert(AlertDescription.unexpected_message); + } + + { + byte[] serverHelloBody = GenerateServerHello(state); + + if (state.maxFragmentLength >= 0) + { + int plainTextLimit = 1 << (8 + state.maxFragmentLength); + recordLayer.SetPlaintextLimit(plainTextLimit); + } + + securityParameters.cipherSuite = state.selectedCipherSuite; + securityParameters.compressionAlgorithm = (byte)state.selectedCompressionMethod; + securityParameters.prfAlgorithm = TlsProtocol.GetPrfAlgorithm(state.serverContext, + state.selectedCipherSuite); + + /* + * RFC 5264 7.4.9. Any cipher suite which does not explicitly specify verify_data_length + * has a verify_data_length equal to 12. This includes all existing cipher suites. + */ + securityParameters.verifyDataLength = 12; + + handshake.SendMessage(HandshakeType.server_hello, serverHelloBody); + } + + handshake.NotifyHelloComplete(); + + IList serverSupplementalData = state.server.GetServerSupplementalData(); + if (serverSupplementalData != null) + { + byte[] supplementalDataBody = GenerateSupplementalData(serverSupplementalData); + handshake.SendMessage(HandshakeType.supplemental_data, supplementalDataBody); + } + + state.keyExchange = state.server.GetKeyExchange(); + state.keyExchange.Init(state.serverContext); + + state.serverCredentials = state.server.GetCredentials(); + + Certificate serverCertificate = null; + + if (state.serverCredentials == null) + { + state.keyExchange.SkipServerCredentials(); + } + else + { + state.keyExchange.ProcessServerCredentials(state.serverCredentials); + + serverCertificate = state.serverCredentials.Certificate; + byte[] certificateBody = GenerateCertificate(serverCertificate); + handshake.SendMessage(HandshakeType.certificate, certificateBody); + } + + // TODO[RFC 3546] Check whether empty certificates is possible, allowed, or excludes CertificateStatus + if (serverCertificate == null || serverCertificate.IsEmpty) + { + state.allowCertificateStatus = false; + } + + if (state.allowCertificateStatus) + { + CertificateStatus certificateStatus = state.server.GetCertificateStatus(); + if (certificateStatus != null) + { + byte[] certificateStatusBody = GenerateCertificateStatus(state, certificateStatus); + handshake.SendMessage(HandshakeType.certificate_status, certificateStatusBody); + } + } + + byte[] serverKeyExchange = state.keyExchange.GenerateServerKeyExchange(); + if (serverKeyExchange != null) + { + handshake.SendMessage(HandshakeType.server_key_exchange, serverKeyExchange); + } + + if (state.serverCredentials != null) + { + state.certificateRequest = state.server.GetCertificateRequest(); + if (state.certificateRequest != null) + { + state.keyExchange.ValidateCertificateRequest(state.certificateRequest); + + byte[] certificateRequestBody = GenerateCertificateRequest(state, state.certificateRequest); + handshake.SendMessage(HandshakeType.certificate_request, certificateRequestBody); + + TlsUtilities.TrackHashAlgorithms(handshake.HandshakeHash, + state.certificateRequest.SupportedSignatureAlgorithms); + } + } + + handshake.SendMessage(HandshakeType.server_hello_done, TlsUtilities.EmptyBytes); + + handshake.HandshakeHash.SealHashAlgorithms(); + + clientMessage = handshake.ReceiveMessage(); + + if (clientMessage.Type == HandshakeType.supplemental_data) + { + ProcessClientSupplementalData(state, clientMessage.Body); + clientMessage = handshake.ReceiveMessage(); + } + else + { + state.server.ProcessClientSupplementalData(null); + } + + if (state.certificateRequest == null) + { + state.keyExchange.SkipClientCredentials(); + } + else + { + if (clientMessage.Type == HandshakeType.certificate) + { + ProcessClientCertificate(state, clientMessage.Body); + clientMessage = handshake.ReceiveMessage(); + } + else + { + if (TlsUtilities.IsTlsV12(state.serverContext)) + { + /* + * RFC 5246 If no suitable certificate is available, the client MUST send a + * certificate message containing no certificates. + * + * NOTE: In previous RFCs, this was SHOULD instead of MUST. + */ + throw new TlsFatalAlert(AlertDescription.unexpected_message); + } + + NotifyClientCertificate(state, Certificate.EmptyChain); + } + } + + if (clientMessage.Type == HandshakeType.client_key_exchange) + { + ProcessClientKeyExchange(state, clientMessage.Body); + } + else + { + throw new TlsFatalAlert(AlertDescription.unexpected_message); + } + + TlsHandshakeHash prepareFinishHash = handshake.PrepareToFinish(); + securityParameters.sessionHash = TlsProtocol.GetCurrentPrfHash(state.serverContext, prepareFinishHash, null); + + TlsProtocol.EstablishMasterSecret(state.serverContext, state.keyExchange); + recordLayer.InitPendingEpoch(state.server.GetCipher()); + + /* + * RFC 5246 7.4.8 This message is only sent following a client certificate that has signing + * capability (i.e., all certificates except those containing fixed Diffie-Hellman + * parameters). + */ + if (ExpectCertificateVerifyMessage(state)) + { + byte[] certificateVerifyBody = handshake.ReceiveMessageBody(HandshakeType.certificate_verify); + ProcessCertificateVerify(state, certificateVerifyBody, prepareFinishHash); + } + + // NOTE: Calculated exclusive of the actual Finished message from the client + byte[] expectedClientVerifyData = TlsUtilities.CalculateVerifyData(state.serverContext, ExporterLabel.client_finished, + TlsProtocol.GetCurrentPrfHash(state.serverContext, handshake.HandshakeHash, null)); + ProcessFinished(handshake.ReceiveMessageBody(HandshakeType.finished), expectedClientVerifyData); + + if (state.expectSessionTicket) + { + NewSessionTicket newSessionTicket = state.server.GetNewSessionTicket(); + byte[] newSessionTicketBody = GenerateNewSessionTicket(state, newSessionTicket); + handshake.SendMessage(HandshakeType.session_ticket, newSessionTicketBody); + } + + // NOTE: Calculated exclusive of the Finished message itself + byte[] serverVerifyData = TlsUtilities.CalculateVerifyData(state.serverContext, ExporterLabel.server_finished, + TlsProtocol.GetCurrentPrfHash(state.serverContext, handshake.HandshakeHash, null)); + handshake.SendMessage(HandshakeType.finished, serverVerifyData); + + handshake.Finish(); + + state.server.NotifyHandshakeComplete(); + + return new DtlsTransport(recordLayer); + } + + protected virtual byte[] GenerateCertificateRequest(ServerHandshakeState state, CertificateRequest certificateRequest) + { + MemoryStream buf = new MemoryStream(); + certificateRequest.Encode(buf); + return buf.ToArray(); + } + + protected virtual byte[] GenerateCertificateStatus(ServerHandshakeState state, CertificateStatus certificateStatus) + { + MemoryStream buf = new MemoryStream(); + certificateStatus.Encode(buf); + return buf.ToArray(); + } + + protected virtual byte[] GenerateNewSessionTicket(ServerHandshakeState state, NewSessionTicket newSessionTicket) + { + MemoryStream buf = new MemoryStream(); + newSessionTicket.Encode(buf); + return buf.ToArray(); + } + + protected virtual byte[] GenerateServerHello(ServerHandshakeState state) + { + SecurityParameters securityParameters = state.serverContext.SecurityParameters; + + MemoryStream buf = new MemoryStream(); + + ProtocolVersion server_version = state.server.GetServerVersion(); + if (!server_version.IsEqualOrEarlierVersionOf(state.serverContext.ClientVersion)) + throw new TlsFatalAlert(AlertDescription.internal_error); + + // TODO Read RFCs for guidance on the expected record layer version number + // recordStream.setReadVersion(server_version); + // recordStream.setWriteVersion(server_version); + // recordStream.setRestrictReadVersion(true); + state.serverContext.SetServerVersion(server_version); + + TlsUtilities.WriteVersion(state.serverContext.ServerVersion, buf); + + buf.Write(securityParameters.ServerRandom, 0, securityParameters.ServerRandom.Length); + + /* + * The server may return an empty session_id to indicate that the session will not be cached + * and therefore cannot be resumed. + */ + TlsUtilities.WriteOpaque8(TlsUtilities.EmptyBytes, buf); + + state.selectedCipherSuite = state.server.GetSelectedCipherSuite(); + if (!Arrays.Contains(state.offeredCipherSuites, state.selectedCipherSuite) + || state.selectedCipherSuite == CipherSuite.TLS_NULL_WITH_NULL_NULL + || CipherSuite.IsScsv(state.selectedCipherSuite) + || !TlsUtilities.IsValidCipherSuiteForVersion(state.selectedCipherSuite, server_version)) + { + throw new TlsFatalAlert(AlertDescription.internal_error); + } + + ValidateSelectedCipherSuite(state.selectedCipherSuite, AlertDescription.internal_error); + + state.selectedCompressionMethod = state.server.GetSelectedCompressionMethod(); + if (!Arrays.Contains(state.offeredCompressionMethods, (byte)state.selectedCompressionMethod)) + throw new TlsFatalAlert(AlertDescription.internal_error); + + TlsUtilities.WriteUint16(state.selectedCipherSuite, buf); + TlsUtilities.WriteUint8((byte)state.selectedCompressionMethod, buf); + + state.serverExtensions = state.server.GetServerExtensions(); + + /* + * RFC 5746 3.6. Server Behavior: Initial Handshake + */ + if (state.secure_renegotiation) + { + byte[] renegExtData = TlsUtilities.GetExtensionData(state.serverExtensions, ExtensionType.renegotiation_info); + bool noRenegExt = (null == renegExtData); + + if (noRenegExt) + { + /* + * Note that sending a "renegotiation_info" extension in response to a ClientHello + * containing only the SCSV is an explicit exception to the prohibition in RFC 5246, + * Section 7.4.1.4, on the server sending unsolicited extensions and is only allowed + * because the client is signaling its willingness to receive the extension via the + * TLS_EMPTY_RENEGOTIATION_INFO_SCSV SCSV. + */ + + /* + * If the secure_renegotiation flag is set to TRUE, the server MUST include an empty + * "renegotiation_info" extension in the ServerHello message. + */ + state.serverExtensions = TlsExtensionsUtilities.EnsureExtensionsInitialised(state.serverExtensions); + state.serverExtensions[ExtensionType.renegotiation_info] = TlsProtocol.CreateRenegotiationInfo(TlsUtilities.EmptyBytes); + } + } + + if (securityParameters.extendedMasterSecret) + { + state.serverExtensions = TlsExtensionsUtilities.EnsureExtensionsInitialised(state.serverExtensions); + TlsExtensionsUtilities.AddExtendedMasterSecretExtension(state.serverExtensions); + } + + if (state.serverExtensions != null) + { + securityParameters.encryptThenMac = TlsExtensionsUtilities.HasEncryptThenMacExtension(state.serverExtensions); + + state.maxFragmentLength = EvaluateMaxFragmentLengthExtension(state.clientExtensions, state.serverExtensions, + AlertDescription.internal_error); + + securityParameters.truncatedHMac = TlsExtensionsUtilities.HasTruncatedHMacExtension(state.serverExtensions); + + state.allowCertificateStatus = TlsUtilities.HasExpectedEmptyExtensionData(state.serverExtensions, + ExtensionType.status_request, AlertDescription.internal_error); + + state.expectSessionTicket = TlsUtilities.HasExpectedEmptyExtensionData(state.serverExtensions, + ExtensionType.session_ticket, AlertDescription.internal_error); + + TlsProtocol.WriteExtensions(buf, state.serverExtensions); + } + + return buf.ToArray(); + } + + protected virtual void NotifyClientCertificate(ServerHandshakeState state, Certificate clientCertificate) + { + if (state.certificateRequest == null) + throw new InvalidOperationException(); + + if (state.clientCertificate != null) + throw new TlsFatalAlert(AlertDescription.unexpected_message); + + state.clientCertificate = clientCertificate; + + if (clientCertificate.IsEmpty) + { + state.keyExchange.SkipClientCredentials(); + } + else + { + + /* + * TODO RFC 5246 7.4.6. If the certificate_authorities list in the certificate request + * message was non-empty, one of the certificates in the certificate chain SHOULD be + * issued by one of the listed CAs. + */ + + state.clientCertificateType = TlsUtilities.GetClientCertificateType(clientCertificate, + state.serverCredentials.Certificate); + + state.keyExchange.ProcessClientCertificate(clientCertificate); + } + + /* + * RFC 5246 7.4.6. If the client does not send any certificates, the server MAY at its + * discretion either continue the handshake without client authentication, or respond with a + * fatal handshake_failure alert. Also, if some aspect of the certificate chain was + * unacceptable (e.g., it was not signed by a known, trusted CA), the server MAY at its + * discretion either continue the handshake (considering the client unauthenticated) or send + * a fatal alert. + */ + state.server.NotifyClientCertificate(clientCertificate); + } + + protected virtual void ProcessClientCertificate(ServerHandshakeState state, byte[] body) + { + MemoryStream buf = new MemoryStream(body, false); + + Certificate clientCertificate = Certificate.Parse(buf); + + TlsProtocol.AssertEmpty(buf); + + NotifyClientCertificate(state, clientCertificate); + } + + protected virtual void ProcessCertificateVerify(ServerHandshakeState state, byte[] body, TlsHandshakeHash prepareFinishHash) + { + MemoryStream buf = new MemoryStream(body, false); + + TlsServerContextImpl context = state.serverContext; + DigitallySigned clientCertificateVerify = DigitallySigned.Parse(context, buf); + + TlsProtocol.AssertEmpty(buf); + + // Verify the CertificateVerify message contains a correct signature. + bool verified = false; + try + { + byte[] hash; + if (TlsUtilities.IsTlsV12(context)) + { + hash = prepareFinishHash.GetFinalHash(clientCertificateVerify.Algorithm.Hash); + } + else + { + hash = context.SecurityParameters.SessionHash; + } + + X509CertificateStructure x509Cert = state.clientCertificate.GetCertificateAt(0); + SubjectPublicKeyInfo keyInfo = x509Cert.SubjectPublicKeyInfo; + AsymmetricKeyParameter publicKey = PublicKeyFactory.CreateKey(keyInfo); + + TlsSigner tlsSigner = TlsUtilities.CreateTlsSigner((byte)state.clientCertificateType); + tlsSigner.Init(context); + verified = tlsSigner.VerifyRawSignature(clientCertificateVerify.Algorithm, + clientCertificateVerify.Signature, publicKey, hash); + } + catch (Exception) + { + } + + if (!verified) + throw new TlsFatalAlert(AlertDescription.decrypt_error); + } + + protected virtual void ProcessClientHello(ServerHandshakeState state, byte[] body) + { + MemoryStream buf = new MemoryStream(body, false); + + // TODO Read RFCs for guidance on the expected record layer version number + ProtocolVersion client_version = TlsUtilities.ReadVersion(buf); + if (!client_version.IsDtls) + throw new TlsFatalAlert(AlertDescription.illegal_parameter); + + /* + * Read the client random + */ + byte[] client_random = TlsUtilities.ReadFully(32, buf); + + byte[] sessionID = TlsUtilities.ReadOpaque8(buf); + if (sessionID.Length > 32) + throw new TlsFatalAlert(AlertDescription.illegal_parameter); + + // TODO RFC 4347 has the cookie length restricted to 32, but not in RFC 6347 + byte[] cookie = TlsUtilities.ReadOpaque8(buf); + + int cipher_suites_length = TlsUtilities.ReadUint16(buf); + if (cipher_suites_length < 2 || (cipher_suites_length & 1) != 0) + { + throw new TlsFatalAlert(AlertDescription.decode_error); + } + + /* + * NOTE: "If the session_id field is not empty (implying a session resumption request) this + * vector must include at least the cipher_suite from that session." + */ + state.offeredCipherSuites = TlsUtilities.ReadUint16Array(cipher_suites_length / 2, buf); + + int compression_methods_length = TlsUtilities.ReadUint8(buf); + if (compression_methods_length < 1) + { + throw new TlsFatalAlert(AlertDescription.illegal_parameter); + } + + state.offeredCompressionMethods = TlsUtilities.ReadUint8Array(compression_methods_length, buf); + + /* + * TODO RFC 3546 2.3 If [...] the older session is resumed, then the server MUST ignore + * extensions appearing in the client hello, and send a server hello containing no + * extensions. + */ + state.clientExtensions = TlsProtocol.ReadExtensions(buf); + + TlsServerContextImpl context = state.serverContext; + SecurityParameters securityParameters = context.SecurityParameters; + + securityParameters.extendedMasterSecret = TlsExtensionsUtilities.HasExtendedMasterSecretExtension(state.clientExtensions); + + context.SetClientVersion(client_version); + + state.server.NotifyClientVersion(client_version); + state.server.NotifyFallback(Arrays.Contains(state.offeredCipherSuites, CipherSuite.TLS_FALLBACK_SCSV)); + + securityParameters.clientRandom = client_random; + + state.server.NotifyOfferedCipherSuites(state.offeredCipherSuites); + state.server.NotifyOfferedCompressionMethods(state.offeredCompressionMethods); + + /* + * RFC 5746 3.6. Server Behavior: Initial Handshake + */ + { + /* + * RFC 5746 3.4. The client MUST include either an empty "renegotiation_info" extension, + * or the TLS_EMPTY_RENEGOTIATION_INFO_SCSV signaling cipher suite value in the + * ClientHello. Including both is NOT RECOMMENDED. + */ + + /* + * When a ClientHello is received, the server MUST check if it includes the + * TLS_EMPTY_RENEGOTIATION_INFO_SCSV SCSV. If it does, set the secure_renegotiation flag + * to TRUE. + */ + if (Arrays.Contains(state.offeredCipherSuites, CipherSuite.TLS_EMPTY_RENEGOTIATION_INFO_SCSV)) + { + state.secure_renegotiation = true; + } + + /* + * The server MUST check if the "renegotiation_info" extension is included in the + * ClientHello. + */ + byte[] renegExtData = TlsUtilities.GetExtensionData(state.clientExtensions, ExtensionType.renegotiation_info); + if (renegExtData != null) + { + /* + * If the extension is present, set secure_renegotiation flag to TRUE. The + * server MUST then verify that the length of the "renegotiated_connection" + * field is zero, and if it is not, MUST abort the handshake. + */ + state.secure_renegotiation = true; + + if (!Arrays.ConstantTimeAreEqual(renegExtData, TlsProtocol.CreateRenegotiationInfo(TlsUtilities.EmptyBytes))) + throw new TlsFatalAlert(AlertDescription.handshake_failure); + } + } + + state.server.NotifySecureRenegotiation(state.secure_renegotiation); + + if (state.clientExtensions != null) + { + state.server.ProcessClientExtensions(state.clientExtensions); + } + } + + protected virtual void ProcessClientKeyExchange(ServerHandshakeState state, byte[] body) + { + MemoryStream buf = new MemoryStream(body, false); + + state.keyExchange.ProcessClientKeyExchange(buf); + + TlsProtocol.AssertEmpty(buf); + } + + protected virtual void ProcessClientSupplementalData(ServerHandshakeState state, byte[] body) + { + MemoryStream buf = new MemoryStream(body, false); + IList clientSupplementalData = TlsProtocol.ReadSupplementalDataMessage(buf); + state.server.ProcessClientSupplementalData(clientSupplementalData); + } + + protected virtual bool ExpectCertificateVerifyMessage(ServerHandshakeState state) + { + return state.clientCertificateType >= 0 && TlsUtilities.HasSigningCapability((byte)state.clientCertificateType); + } + + protected internal class ServerHandshakeState + { + internal TlsServer server = null; + internal TlsServerContextImpl serverContext = null; + internal int[] offeredCipherSuites; + internal byte[] offeredCompressionMethods; + internal IDictionary clientExtensions; + internal int selectedCipherSuite = -1; + internal short selectedCompressionMethod = -1; + internal bool secure_renegotiation = false; + internal short maxFragmentLength = -1; + internal bool allowCertificateStatus = false; + internal bool expectSessionTicket = false; + internal IDictionary serverExtensions = null; + internal TlsKeyExchange keyExchange = null; + internal TlsCredentials serverCredentials = null; + internal CertificateRequest certificateRequest = null; + internal short clientCertificateType = -1; + internal Certificate clientCertificate = null; + } + } +} diff --git a/crypto/src/crypto/tls/DtlsTransport.cs b/crypto/src/crypto/tls/DtlsTransport.cs new file mode 100644 index 000000000..5c607336b --- /dev/null +++ b/crypto/src/crypto/tls/DtlsTransport.cs @@ -0,0 +1,77 @@ +using System; +using System.IO; + +namespace Org.BouncyCastle.Crypto.Tls +{ + public class DtlsTransport + : DatagramTransport + { + private readonly DtlsRecordLayer mRecordLayer; + + internal DtlsTransport(DtlsRecordLayer recordLayer) + { + this.mRecordLayer = recordLayer; + } + + public virtual int GetReceiveLimit() + { + return mRecordLayer.GetReceiveLimit(); + } + + public virtual int GetSendLimit() + { + return mRecordLayer.GetSendLimit(); + } + + public virtual int Receive(byte[] buf, int off, int len, int waitMillis) + { + try + { + return mRecordLayer.Receive(buf, off, len, waitMillis); + } + catch (TlsFatalAlert fatalAlert) + { + mRecordLayer.Fail(fatalAlert.AlertDescription); + throw fatalAlert; + } + catch (IOException e) + { + mRecordLayer.Fail(AlertDescription.internal_error); + throw e; + } + catch (Exception e) + { + mRecordLayer.Fail(AlertDescription.internal_error); + throw new TlsFatalAlert(AlertDescription.internal_error, e); + } + } + + public virtual void Send(byte[] buf, int off, int len) + { + try + { + mRecordLayer.Send(buf, off, len); + } + catch (TlsFatalAlert fatalAlert) + { + mRecordLayer.Fail(fatalAlert.AlertDescription); + throw fatalAlert; + } + catch (IOException e) + { + mRecordLayer.Fail(AlertDescription.internal_error); + throw e; + } + catch (Exception e) + { + mRecordLayer.Fail(AlertDescription.internal_error); + throw new TlsFatalAlert(AlertDescription.internal_error, e); + } + } + + public virtual void Close() + { + mRecordLayer.Close(); + } + } +} |