diff options
Diffstat (limited to 'crypto/src/tls/DtlsClientProtocol.cs')
-rw-r--r-- | crypto/src/tls/DtlsClientProtocol.cs | 145 |
1 files changed, 76 insertions, 69 deletions
diff --git a/crypto/src/tls/DtlsClientProtocol.cs b/crypto/src/tls/DtlsClientProtocol.cs index e96e161d4..3f52d3c6b 100644 --- a/crypto/src/tls/DtlsClientProtocol.cs +++ b/crypto/src/tls/DtlsClientProtocol.cs @@ -23,17 +23,19 @@ namespace Org.BouncyCastle.Tls if (transport == null) throw new ArgumentNullException("transport"); + TlsClientContextImpl clientContext = new TlsClientContextImpl(client.Crypto); + ClientHandshakeState state = new ClientHandshakeState(); state.client = client; - state.clientContext = new TlsClientContextImpl(client.Crypto); + state.clientContext = clientContext; - client.Init(state.clientContext); - state.clientContext.HandshakeBeginning(client); + client.Init(clientContext); + clientContext.HandshakeBeginning(client); - SecurityParameters securityParameters = state.clientContext.SecurityParameters; + SecurityParameters securityParameters = clientContext.SecurityParameters; securityParameters.m_extendedPadding = client.ShouldUseExtendedPadding(); - TlsSession sessionToResume = state.client.GetSessionToResume(); + TlsSession sessionToResume = client.GetSessionToResume(); if (sessionToResume != null && sessionToResume.IsResumable) { SessionParameters sessionParameters = sessionToResume.ExportSessionParameters(); @@ -44,7 +46,7 @@ namespace Org.BouncyCastle.Tls */ if (sessionParameters != null && (sessionParameters.IsExtendedMasterSecret - || (!state.client.RequiresExtendedMasterSecret() && state.client.AllowLegacyResumption()))) + || (!client.RequiresExtendedMasterSecret() && client.AllowLegacyResumption()))) { TlsSecret masterSecret = sessionParameters.MasterSecret; lock (masterSecret) @@ -53,13 +55,13 @@ namespace Org.BouncyCastle.Tls { state.tlsSession = sessionToResume; state.sessionParameters = sessionParameters; - state.sessionMasterSecret = state.clientContext.Crypto.AdoptSecret(masterSecret); + state.sessionMasterSecret = clientContext.Crypto.AdoptSecret(masterSecret); } } } } - DtlsRecordLayer recordLayer = new DtlsRecordLayer(state.clientContext, state.client, transport); + DtlsRecordLayer recordLayer = new DtlsRecordLayer(clientContext, client, transport); client.NotifyCloseHandle(recordLayer); try @@ -97,11 +99,12 @@ namespace Org.BouncyCastle.Tls /// <exception cref="IOException"/> internal virtual DtlsTransport ClientHandshake(ClientHandshakeState state, DtlsRecordLayer recordLayer) { - SecurityParameters securityParameters = state.clientContext.SecurityParameters; + TlsClient client = state.client; + TlsClientContextImpl clientContext = state.clientContext; + SecurityParameters securityParameters = clientContext.SecurityParameters; - DtlsReliableHandshake handshake = new DtlsReliableHandshake(state.clientContext, recordLayer, - state.client.GetHandshakeTimeoutMillis(), TlsUtilities.GetHandshakeResendTimeMillis(state.client), - null); + DtlsReliableHandshake handshake = new DtlsReliableHandshake(clientContext, recordLayer, + client.GetHandshakeTimeoutMillis(), TlsUtilities.GetHandshakeResendTimeMillis(client), null); byte[] clientHelloBody = GenerateClientHello(state); @@ -144,16 +147,16 @@ namespace Org.BouncyCastle.Tls if (securityParameters.IsResumedSession) { securityParameters.m_masterSecret = state.sessionMasterSecret; - recordLayer.InitPendingEpoch(TlsUtilities.InitCipher(state.clientContext)); + recordLayer.InitPendingEpoch(TlsUtilities.InitCipher(clientContext)); // NOTE: Calculated exclusive of the actual Finished message from the server - securityParameters.m_peerVerifyData = TlsUtilities.CalculateVerifyData(state.clientContext, + securityParameters.m_peerVerifyData = TlsUtilities.CalculateVerifyData(clientContext, handshake.HandshakeHash, true); ProcessFinished(handshake.ReceiveMessageBody(HandshakeType.finished), securityParameters.PeerVerifyData); // NOTE: Calculated exclusive of the Finished message itself - securityParameters.m_localVerifyData = TlsUtilities.CalculateVerifyData(state.clientContext, + securityParameters.m_localVerifyData = TlsUtilities.CalculateVerifyData(clientContext, handshake.HandshakeHash, false); handshake.SendMessage(HandshakeType.finished, securityParameters.LocalVerifyData); @@ -169,12 +172,12 @@ namespace Org.BouncyCastle.Tls securityParameters.m_pskIdentity = state.sessionParameters.PskIdentity; securityParameters.m_srpIdentity = state.sessionParameters.SrpIdentity; - state.clientContext.HandshakeComplete(state.client, state.tlsSession); + clientContext.HandshakeComplete(client, state.tlsSession); recordLayer.InitHeartbeat(state.heartbeat, HeartbeatMode.peer_allowed_to_send == state.heartbeatPolicy); - return new DtlsTransport(recordLayer, state.client.IgnoreCorruptDtlsRecords); + return new DtlsTransport(recordLayer, client.IgnoreCorruptDtlsRecords); } InvalidateSession(state); @@ -189,10 +192,10 @@ namespace Org.BouncyCastle.Tls } else { - state.client.ProcessServerSupplementalData(null); + client.ProcessServerSupplementalData(null); } - state.keyExchange = TlsUtilities.InitKeyExchangeClient(state.clientContext, state.client); + state.keyExchange = TlsUtilities.InitKeyExchangeClient(clientContext, client); if (serverMessage.Type == HandshakeType.certificate) { @@ -218,7 +221,7 @@ namespace Org.BouncyCastle.Tls // Okay, CertificateStatus is optional } - TlsUtilities.ProcessServerCertificate(state.clientContext, state.certificateStatus, state.keyExchange, + TlsUtilities.ProcessServerCertificate(clientContext, state.certificateStatus, state.keyExchange, state.authentication, state.clientExtensions, state.serverExtensions); if (serverMessage.Type == HandshakeType.server_key_exchange) @@ -308,7 +311,7 @@ namespace Org.BouncyCastle.Tls state.keyExchange.ProcessClientCredentials(clientAuthCredentials); } - var clientSupplementalData = state.client.GetClientSupplementalData(); + var clientSupplementalData = client.GetClientSupplementalData(); if (clientSupplementalData != null) { byte[] supplementalDataBody = GenerateSupplementalData(clientSupplementalData); @@ -317,7 +320,7 @@ namespace Org.BouncyCastle.Tls if (null != state.certificateRequest) { - SendCertificateMessage(state.clientContext, handshake, clientAuthCertificate, null); + SendCertificateMessage(clientContext, handshake, clientAuthCertificate, null); } byte[] clientKeyExchangeBody = GenerateClientKeyExchange(state); @@ -325,12 +328,12 @@ namespace Org.BouncyCastle.Tls securityParameters.m_sessionHash = TlsUtilities.GetCurrentPrfHash(handshake.HandshakeHash); - TlsProtocol.EstablishMasterSecret(state.clientContext, state.keyExchange); - recordLayer.InitPendingEpoch(TlsUtilities.InitCipher(state.clientContext)); + TlsProtocol.EstablishMasterSecret(clientContext, state.keyExchange); + recordLayer.InitPendingEpoch(TlsUtilities.InitCipher(clientContext)); if (clientAuthSigner != null) { - DigitallySigned certificateVerify = TlsUtilities.GenerateCertificateVerifyClient(state.clientContext, + DigitallySigned certificateVerify = TlsUtilities.GenerateCertificateVerifyClient(clientContext, clientAuthSigner, clientAuthAlgorithm, clientAuthStreamSigner, handshake.HandshakeHash); byte[] certificateVerifyBody = GenerateCertificateVerify(state, certificateVerify); handshake.SendMessage(HandshakeType.certificate_verify, certificateVerifyBody); @@ -338,7 +341,7 @@ namespace Org.BouncyCastle.Tls handshake.PrepareToFinish(); - securityParameters.m_localVerifyData = TlsUtilities.CalculateVerifyData(state.clientContext, + securityParameters.m_localVerifyData = TlsUtilities.CalculateVerifyData(clientContext, handshake.HandshakeHash, false); handshake.SendMessage(HandshakeType.finished, securityParameters.LocalVerifyData); @@ -364,7 +367,7 @@ namespace Org.BouncyCastle.Tls } // NOTE: Calculated exclusive of the actual Finished message from the server - securityParameters.m_peerVerifyData = TlsUtilities.CalculateVerifyData(state.clientContext, + securityParameters.m_peerVerifyData = TlsUtilities.CalculateVerifyData(clientContext, handshake.HandshakeHash, true); ProcessFinished(handshake.ReceiveMessageBody(HandshakeType.finished), securityParameters.PeerVerifyData); @@ -376,7 +379,7 @@ namespace Org.BouncyCastle.Tls .SetCipherSuite(securityParameters.CipherSuite) .SetExtendedMasterSecret(securityParameters.IsExtendedMasterSecret) .SetLocalCertificate(securityParameters.LocalCertificate) - .SetMasterSecret(state.clientContext.Crypto.AdoptSecret(state.sessionMasterSecret)) + .SetMasterSecret(clientContext.Crypto.AdoptSecret(state.sessionMasterSecret)) .SetNegotiatedVersion(securityParameters.NegotiatedVersion) .SetPeerCertificate(securityParameters.PeerCertificate) .SetPskIdentity(securityParameters.PskIdentity) @@ -389,11 +392,11 @@ namespace Org.BouncyCastle.Tls securityParameters.m_tlsUnique = securityParameters.LocalVerifyData; - state.clientContext.HandshakeComplete(state.client, state.tlsSession); + clientContext.HandshakeComplete(client, state.tlsSession); recordLayer.InitHeartbeat(state.heartbeat, HeartbeatMode.peer_allowed_to_send == state.heartbeatPolicy); - return new DtlsTransport(recordLayer, state.client.IgnoreCorruptDtlsRecords); + return new DtlsTransport(recordLayer, client.IgnoreCorruptDtlsRecords); } /// <exception cref="IOException"/> @@ -408,29 +411,30 @@ namespace Org.BouncyCastle.Tls /// <exception cref="IOException"/> protected virtual byte[] GenerateClientHello(ClientHandshakeState state) { - TlsClientContextImpl context = state.clientContext; - SecurityParameters securityParameters = context.SecurityParameters; + TlsClient client = state.client; + TlsClientContextImpl clientContext = state.clientContext; + SecurityParameters securityParameters = clientContext.SecurityParameters; - context.SetClientSupportedVersions(state.client.GetProtocolVersions()); + clientContext.SetClientSupportedVersions(client.GetProtocolVersions()); - ProtocolVersion client_version = ProtocolVersion.GetLatestDtls(context.ClientSupportedVersions); + ProtocolVersion client_version = ProtocolVersion.GetLatestDtls(clientContext.ClientSupportedVersions); if (!ProtocolVersion.IsSupportedDtlsVersionClient(client_version)) throw new TlsFatalAlert(AlertDescription.internal_error); - context.SetClientVersion(client_version); + clientContext.SetClientVersion(client_version); { bool useGmtUnixTime = ProtocolVersion.DTLSv12.IsEqualOrLaterVersionOf(client_version) - && state.client.ShouldUseGmtUnixTime(); + && client.ShouldUseGmtUnixTime(); - securityParameters.m_clientRandom = TlsProtocol.CreateRandomBlock(useGmtUnixTime, state.clientContext); + securityParameters.m_clientRandom = TlsProtocol.CreateRandomBlock(useGmtUnixTime, clientContext); } byte[] session_id = TlsUtilities.GetSessionID(state.tlsSession); - bool fallback = state.client.IsFallback(); + bool fallback = client.IsFallback(); - state.offeredCipherSuites = state.client.GetCipherSuites(); + state.offeredCipherSuites = client.GetCipherSuites(); if (session_id.Length > 0 && state.sessionParameters != null) { @@ -440,8 +444,7 @@ namespace Org.BouncyCastle.Tls } } - state.clientExtensions = TlsExtensionsUtilities.EnsureExtensionsInitialised( - state.client.GetClientExtensions()); + state.clientExtensions = TlsExtensionsUtilities.EnsureExtensionsInitialised(client.GetClientExtensions()); ProtocolVersion legacy_version = client_version; if (client_version.IsLaterVersionOf(ProtocolVersion.DTLSv12)) @@ -449,10 +452,10 @@ namespace Org.BouncyCastle.Tls legacy_version = ProtocolVersion.DTLSv12; TlsExtensionsUtilities.AddSupportedVersionsExtensionClient(state.clientExtensions, - context.ClientSupportedVersions); + clientContext.ClientSupportedVersions); } - context.SetRsaPreMasterSecretVersion(legacy_version); + clientContext.SetRsaPreMasterSecretVersion(legacy_version); securityParameters.m_clientServerNames = TlsExtensionsUtilities.GetServerNameExtensionClient( state.clientExtensions); @@ -465,16 +468,16 @@ namespace Org.BouncyCastle.Tls securityParameters.m_clientSupportedGroups = TlsExtensionsUtilities.GetSupportedGroupsExtension( state.clientExtensions); - state.clientAgreements = TlsUtilities.AddKeyShareToClientHello(state.clientContext, state.client, + state.clientAgreements = TlsUtilities.AddKeyShareToClientHello(clientContext, client, state.clientExtensions); - if (TlsUtilities.IsExtendedMasterSecretOptional(context.ClientSupportedVersions) - && state.client.ShouldUseExtendedMasterSecret()) + if (TlsUtilities.IsExtendedMasterSecretOptional(clientContext.ClientSupportedVersions) + && client.ShouldUseExtendedMasterSecret()) { TlsExtensionsUtilities.AddExtendedMasterSecretExtension(state.clientExtensions); } else if (!TlsUtilities.IsTlsV13(client_version) - && state.client.RequiresExtendedMasterSecret()) + && client.RequiresExtendedMasterSecret()) { throw new TlsFatalAlert(AlertDescription.internal_error); } @@ -512,8 +515,8 @@ namespace Org.BouncyCastle.Tls // Heartbeats { - state.heartbeat = state.client.GetHeartbeat(); - state.heartbeatPolicy = state.client.GetHeartbeatPolicy(); + state.heartbeat = client.GetHeartbeat(); + state.heartbeatPolicy = client.GetHeartbeatPolicy(); if (null != state.heartbeat || HeartbeatMode.peer_allowed_to_send == state.heartbeatPolicy) { @@ -528,7 +531,7 @@ namespace Org.BouncyCastle.Tls cookie: TlsUtilities.EmptyBytes, state.offeredCipherSuites, state.clientExtensions, 0); MemoryStream buf = new MemoryStream(); - clientHello.Encode(state.clientContext, buf); + clientHello.Encode(clientContext, buf); return buf.ToArray(); } @@ -573,17 +576,19 @@ namespace Org.BouncyCastle.Tls throw new TlsFatalAlert(AlertDescription.handshake_failure); } + TlsClientContextImpl clientContext = state.clientContext; + SecurityParameters securityParameters = clientContext.SecurityParameters; + MemoryStream buf = new MemoryStream(body, false); - CertificateRequest certificateRequest = CertificateRequest.Parse(state.clientContext, buf); + CertificateRequest certificateRequest = CertificateRequest.Parse(clientContext, buf); TlsProtocol.AssertEmpty(buf); state.certificateRequest = TlsUtilities.ValidateCertificateRequest(certificateRequest, state.keyExchange); - state.clientContext.SecurityParameters.m_clientCertificateType = - TlsExtensionsUtilities.GetClientCertificateTypeExtensionServer(state.serverExtensions, - CertificateType.X509); + securityParameters.m_clientCertificateType = TlsExtensionsUtilities.GetClientCertificateTypeExtensionServer( + state.serverExtensions, CertificateType.X509); } /// <exception cref="IOException"/> @@ -644,6 +649,11 @@ namespace Org.BouncyCastle.Tls /// <exception cref="IOException"/> protected virtual void ProcessServerHello(ClientHandshakeState state, byte[] body) { + TlsClient client = state.client; + TlsClientContextImpl clientContext = state.clientContext; + SecurityParameters securityParameters = clientContext.SecurityParameters; + + MemoryStream buf = new MemoryStream(body, false); ServerHello serverHello = ServerHello.Parse(buf); @@ -652,16 +662,13 @@ namespace Org.BouncyCastle.Tls state.serverExtensions = serverHello.Extensions; - - SecurityParameters securityParameters = state.clientContext.SecurityParameters; - // TODO[dtls13] Check supported_version extension for negotiated version ReportServerVersion(state, server_version); securityParameters.m_serverRandom = serverHello.Random; - if (!state.clientContext.ClientVersion.Equals(server_version)) + if (!clientContext.ClientVersion.Equals(server_version)) { TlsUtilities.CheckDowngradeMarker(server_version, securityParameters.ServerRandom); } @@ -669,7 +676,7 @@ namespace Org.BouncyCastle.Tls { byte[] selectedSessionID = serverHello.SessionID; securityParameters.m_sessionID = selectedSessionID; - state.client.NotifySessionID(selectedSessionID); + client.NotifySessionID(selectedSessionID); securityParameters.m_resumedSession = selectedSessionID.Length > 0 && state.tlsSession != null && Arrays.AreEqual(selectedSessionID, state.tlsSession.SessionID); } @@ -689,7 +696,7 @@ namespace Org.BouncyCastle.Tls } TlsUtilities.NegotiatedCipherSuite(securityParameters, cipherSuite); - state.client.NotifySelectedCipherSuite(cipherSuite); + client.NotifySelectedCipherSuite(cipherSuite); } /* @@ -726,13 +733,13 @@ namespace Org.BouncyCastle.Tls if (acceptedExtendedMasterSecret) { - if (!securityParameters.IsResumedSession && !state.client.ShouldUseExtendedMasterSecret()) + if (!securityParameters.IsResumedSession && !client.ShouldUseExtendedMasterSecret()) throw new TlsFatalAlert(AlertDescription.handshake_failure); } else { - if (state.client.RequiresExtendedMasterSecret() - || (securityParameters.IsResumedSession && !state.client.AllowLegacyResumption())) + if (client.RequiresExtendedMasterSecret() + || (securityParameters.IsResumedSession && !client.AllowLegacyResumption())) { throw new TlsFatalAlert(AlertDescription.handshake_failure); } @@ -815,7 +822,7 @@ namespace Org.BouncyCastle.Tls } // TODO[compat-gnutls] GnuTLS test server fails to send renegotiation_info extension when resuming - state.client.NotifySecureRenegotiation(securityParameters.IsSecureRenegotiation); + client.NotifySecureRenegotiation(securityParameters.IsSecureRenegotiation); /* * RFC 7301 3.1. When session resumption or session tickets [...] are used, the previous @@ -920,7 +927,7 @@ namespace Org.BouncyCastle.Tls if (sessionClientExtensions != null) { - state.client.ProcessServerExtensions(sessionServerExtensions); + client.ProcessServerExtensions(sessionServerExtensions); } } @@ -943,8 +950,8 @@ namespace Org.BouncyCastle.Tls /// <exception cref="IOException"/> protected virtual void ReportServerVersion(ClientHandshakeState state, ProtocolVersion server_version) { - TlsClientContextImpl context = state.clientContext; - SecurityParameters securityParameters = context.SecurityParameters; + TlsClientContextImpl clientContext = state.clientContext; + SecurityParameters securityParameters = clientContext.SecurityParameters; ProtocolVersion currentServerVersion = securityParameters.NegotiatedVersion; if (null != currentServerVersion) @@ -955,12 +962,12 @@ namespace Org.BouncyCastle.Tls return; } - if (!ProtocolVersion.Contains(context.ClientSupportedVersions, server_version)) + if (!ProtocolVersion.Contains(clientContext.ClientSupportedVersions, server_version)) throw new TlsFatalAlert(AlertDescription.protocol_version); securityParameters.m_negotiatedVersion = server_version; - TlsUtilities.NegotiatedVersionDtlsClient(state.clientContext, state.client); + TlsUtilities.NegotiatedVersionDtlsClient(clientContext, state.client); } /// <exception cref="IOException"/> |