summary refs log tree commit diff
path: root/crypto/src/tls/DtlsClientProtocol.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/tls/DtlsClientProtocol.cs')
-rw-r--r--crypto/src/tls/DtlsClientProtocol.cs145
1 files changed, 76 insertions, 69 deletions
diff --git a/crypto/src/tls/DtlsClientProtocol.cs b/crypto/src/tls/DtlsClientProtocol.cs
index e96e161d4..3f52d3c6b 100644
--- a/crypto/src/tls/DtlsClientProtocol.cs
+++ b/crypto/src/tls/DtlsClientProtocol.cs
@@ -23,17 +23,19 @@ namespace Org.BouncyCastle.Tls
             if (transport == null)
                 throw new ArgumentNullException("transport");
 
+            TlsClientContextImpl clientContext = new TlsClientContextImpl(client.Crypto);
+
             ClientHandshakeState state = new ClientHandshakeState();
             state.client = client;
-            state.clientContext = new TlsClientContextImpl(client.Crypto);
+            state.clientContext = clientContext;
 
-            client.Init(state.clientContext);
-            state.clientContext.HandshakeBeginning(client);
+            client.Init(clientContext);
+            clientContext.HandshakeBeginning(client);
 
-            SecurityParameters securityParameters = state.clientContext.SecurityParameters;
+            SecurityParameters securityParameters = clientContext.SecurityParameters;
             securityParameters.m_extendedPadding = client.ShouldUseExtendedPadding();
 
-            TlsSession sessionToResume = state.client.GetSessionToResume();
+            TlsSession sessionToResume = client.GetSessionToResume();
             if (sessionToResume != null && sessionToResume.IsResumable)
             {
                 SessionParameters sessionParameters = sessionToResume.ExportSessionParameters();
@@ -44,7 +46,7 @@ namespace Org.BouncyCastle.Tls
                  */
                 if (sessionParameters != null
                     && (sessionParameters.IsExtendedMasterSecret
-                        || (!state.client.RequiresExtendedMasterSecret() && state.client.AllowLegacyResumption())))
+                        || (!client.RequiresExtendedMasterSecret() && client.AllowLegacyResumption())))
                 {
                     TlsSecret masterSecret = sessionParameters.MasterSecret;
                     lock (masterSecret)
@@ -53,13 +55,13 @@ namespace Org.BouncyCastle.Tls
                         {
                             state.tlsSession = sessionToResume;
                             state.sessionParameters = sessionParameters;
-                            state.sessionMasterSecret = state.clientContext.Crypto.AdoptSecret(masterSecret);
+                            state.sessionMasterSecret = clientContext.Crypto.AdoptSecret(masterSecret);
                         }
                     }
                 }
             }
 
-            DtlsRecordLayer recordLayer = new DtlsRecordLayer(state.clientContext, state.client, transport);
+            DtlsRecordLayer recordLayer = new DtlsRecordLayer(clientContext, client, transport);
             client.NotifyCloseHandle(recordLayer);
 
             try
@@ -97,11 +99,12 @@ namespace Org.BouncyCastle.Tls
         /// <exception cref="IOException"/>
         internal virtual DtlsTransport ClientHandshake(ClientHandshakeState state, DtlsRecordLayer recordLayer)
         {
-            SecurityParameters securityParameters = state.clientContext.SecurityParameters;
+            TlsClient client = state.client;
+            TlsClientContextImpl clientContext = state.clientContext;
+            SecurityParameters securityParameters = clientContext.SecurityParameters;
 
-            DtlsReliableHandshake handshake = new DtlsReliableHandshake(state.clientContext, recordLayer,
-                state.client.GetHandshakeTimeoutMillis(), TlsUtilities.GetHandshakeResendTimeMillis(state.client),
-                null);
+            DtlsReliableHandshake handshake = new DtlsReliableHandshake(clientContext, recordLayer,
+                client.GetHandshakeTimeoutMillis(), TlsUtilities.GetHandshakeResendTimeMillis(client), null);
 
             byte[] clientHelloBody = GenerateClientHello(state);
 
@@ -144,16 +147,16 @@ namespace Org.BouncyCastle.Tls
             if (securityParameters.IsResumedSession)
             {
                 securityParameters.m_masterSecret = state.sessionMasterSecret;
-                recordLayer.InitPendingEpoch(TlsUtilities.InitCipher(state.clientContext));
+                recordLayer.InitPendingEpoch(TlsUtilities.InitCipher(clientContext));
 
                 // NOTE: Calculated exclusive of the actual Finished message from the server
-                securityParameters.m_peerVerifyData = TlsUtilities.CalculateVerifyData(state.clientContext,
+                securityParameters.m_peerVerifyData = TlsUtilities.CalculateVerifyData(clientContext,
                     handshake.HandshakeHash, true);
                 ProcessFinished(handshake.ReceiveMessageBody(HandshakeType.finished),
                     securityParameters.PeerVerifyData);
 
                 // NOTE: Calculated exclusive of the Finished message itself
-                securityParameters.m_localVerifyData = TlsUtilities.CalculateVerifyData(state.clientContext,
+                securityParameters.m_localVerifyData = TlsUtilities.CalculateVerifyData(clientContext,
                     handshake.HandshakeHash, false);
                 handshake.SendMessage(HandshakeType.finished, securityParameters.LocalVerifyData);
 
@@ -169,12 +172,12 @@ namespace Org.BouncyCastle.Tls
                 securityParameters.m_pskIdentity = state.sessionParameters.PskIdentity;
                 securityParameters.m_srpIdentity = state.sessionParameters.SrpIdentity;
 
-                state.clientContext.HandshakeComplete(state.client, state.tlsSession);
+                clientContext.HandshakeComplete(client, state.tlsSession);
 
                 recordLayer.InitHeartbeat(state.heartbeat,
                     HeartbeatMode.peer_allowed_to_send == state.heartbeatPolicy);
 
-                return new DtlsTransport(recordLayer, state.client.IgnoreCorruptDtlsRecords);
+                return new DtlsTransport(recordLayer, client.IgnoreCorruptDtlsRecords);
             }
 
             InvalidateSession(state);
@@ -189,10 +192,10 @@ namespace Org.BouncyCastle.Tls
             }
             else
             {
-                state.client.ProcessServerSupplementalData(null);
+                client.ProcessServerSupplementalData(null);
             }
 
-            state.keyExchange = TlsUtilities.InitKeyExchangeClient(state.clientContext, state.client);
+            state.keyExchange = TlsUtilities.InitKeyExchangeClient(clientContext, client);
 
             if (serverMessage.Type == HandshakeType.certificate)
             {
@@ -218,7 +221,7 @@ namespace Org.BouncyCastle.Tls
                 // Okay, CertificateStatus is optional
             }
 
-            TlsUtilities.ProcessServerCertificate(state.clientContext, state.certificateStatus, state.keyExchange,
+            TlsUtilities.ProcessServerCertificate(clientContext, state.certificateStatus, state.keyExchange,
                 state.authentication, state.clientExtensions, state.serverExtensions);
 
             if (serverMessage.Type == HandshakeType.server_key_exchange)
@@ -308,7 +311,7 @@ namespace Org.BouncyCastle.Tls
                 state.keyExchange.ProcessClientCredentials(clientAuthCredentials);                    
             }
 
-            var clientSupplementalData = state.client.GetClientSupplementalData();
+            var clientSupplementalData = client.GetClientSupplementalData();
             if (clientSupplementalData != null)
             {
                 byte[] supplementalDataBody = GenerateSupplementalData(clientSupplementalData);
@@ -317,7 +320,7 @@ namespace Org.BouncyCastle.Tls
 
             if (null != state.certificateRequest)
             {
-                SendCertificateMessage(state.clientContext, handshake, clientAuthCertificate, null);
+                SendCertificateMessage(clientContext, handshake, clientAuthCertificate, null);
             }
 
             byte[] clientKeyExchangeBody = GenerateClientKeyExchange(state);
@@ -325,12 +328,12 @@ namespace Org.BouncyCastle.Tls
 
             securityParameters.m_sessionHash = TlsUtilities.GetCurrentPrfHash(handshake.HandshakeHash);
 
-            TlsProtocol.EstablishMasterSecret(state.clientContext, state.keyExchange);
-            recordLayer.InitPendingEpoch(TlsUtilities.InitCipher(state.clientContext));
+            TlsProtocol.EstablishMasterSecret(clientContext, state.keyExchange);
+            recordLayer.InitPendingEpoch(TlsUtilities.InitCipher(clientContext));
 
             if (clientAuthSigner != null)
             {
-                DigitallySigned certificateVerify = TlsUtilities.GenerateCertificateVerifyClient(state.clientContext,
+                DigitallySigned certificateVerify = TlsUtilities.GenerateCertificateVerifyClient(clientContext,
                     clientAuthSigner, clientAuthAlgorithm, clientAuthStreamSigner, handshake.HandshakeHash);
                 byte[] certificateVerifyBody = GenerateCertificateVerify(state, certificateVerify);
                 handshake.SendMessage(HandshakeType.certificate_verify, certificateVerifyBody);
@@ -338,7 +341,7 @@ namespace Org.BouncyCastle.Tls
 
             handshake.PrepareToFinish();
 
-            securityParameters.m_localVerifyData = TlsUtilities.CalculateVerifyData(state.clientContext,
+            securityParameters.m_localVerifyData = TlsUtilities.CalculateVerifyData(clientContext,
                 handshake.HandshakeHash, false);
             handshake.SendMessage(HandshakeType.finished, securityParameters.LocalVerifyData);
 
@@ -364,7 +367,7 @@ namespace Org.BouncyCastle.Tls
             }
 
             // NOTE: Calculated exclusive of the actual Finished message from the server
-            securityParameters.m_peerVerifyData = TlsUtilities.CalculateVerifyData(state.clientContext,
+            securityParameters.m_peerVerifyData = TlsUtilities.CalculateVerifyData(clientContext,
                 handshake.HandshakeHash, true);
             ProcessFinished(handshake.ReceiveMessageBody(HandshakeType.finished), securityParameters.PeerVerifyData);
 
@@ -376,7 +379,7 @@ namespace Org.BouncyCastle.Tls
                 .SetCipherSuite(securityParameters.CipherSuite)
                 .SetExtendedMasterSecret(securityParameters.IsExtendedMasterSecret)
                 .SetLocalCertificate(securityParameters.LocalCertificate)
-                .SetMasterSecret(state.clientContext.Crypto.AdoptSecret(state.sessionMasterSecret))
+                .SetMasterSecret(clientContext.Crypto.AdoptSecret(state.sessionMasterSecret))
                 .SetNegotiatedVersion(securityParameters.NegotiatedVersion)
                 .SetPeerCertificate(securityParameters.PeerCertificate)
                 .SetPskIdentity(securityParameters.PskIdentity)
@@ -389,11 +392,11 @@ namespace Org.BouncyCastle.Tls
 
             securityParameters.m_tlsUnique = securityParameters.LocalVerifyData;
 
-            state.clientContext.HandshakeComplete(state.client, state.tlsSession);
+            clientContext.HandshakeComplete(client, state.tlsSession);
 
             recordLayer.InitHeartbeat(state.heartbeat, HeartbeatMode.peer_allowed_to_send == state.heartbeatPolicy);
 
-            return new DtlsTransport(recordLayer, state.client.IgnoreCorruptDtlsRecords);
+            return new DtlsTransport(recordLayer, client.IgnoreCorruptDtlsRecords);
         }
 
         /// <exception cref="IOException"/>
@@ -408,29 +411,30 @@ namespace Org.BouncyCastle.Tls
         /// <exception cref="IOException"/>
         protected virtual byte[] GenerateClientHello(ClientHandshakeState state)
         {
-            TlsClientContextImpl context = state.clientContext;
-            SecurityParameters securityParameters = context.SecurityParameters;
+            TlsClient client = state.client;
+            TlsClientContextImpl clientContext = state.clientContext;
+            SecurityParameters securityParameters = clientContext.SecurityParameters;
 
-            context.SetClientSupportedVersions(state.client.GetProtocolVersions());
+            clientContext.SetClientSupportedVersions(client.GetProtocolVersions());
 
-            ProtocolVersion client_version = ProtocolVersion.GetLatestDtls(context.ClientSupportedVersions);
+            ProtocolVersion client_version = ProtocolVersion.GetLatestDtls(clientContext.ClientSupportedVersions);
             if (!ProtocolVersion.IsSupportedDtlsVersionClient(client_version))
                 throw new TlsFatalAlert(AlertDescription.internal_error);
 
-            context.SetClientVersion(client_version);
+            clientContext.SetClientVersion(client_version);
 
             {
                 bool useGmtUnixTime = ProtocolVersion.DTLSv12.IsEqualOrLaterVersionOf(client_version)
-                    && state.client.ShouldUseGmtUnixTime();
+                    && client.ShouldUseGmtUnixTime();
 
-                securityParameters.m_clientRandom = TlsProtocol.CreateRandomBlock(useGmtUnixTime, state.clientContext);
+                securityParameters.m_clientRandom = TlsProtocol.CreateRandomBlock(useGmtUnixTime, clientContext);
             }
 
             byte[] session_id = TlsUtilities.GetSessionID(state.tlsSession);
 
-            bool fallback = state.client.IsFallback();
+            bool fallback = client.IsFallback();
 
-            state.offeredCipherSuites = state.client.GetCipherSuites();
+            state.offeredCipherSuites = client.GetCipherSuites();
 
             if (session_id.Length > 0 && state.sessionParameters != null)
             {
@@ -440,8 +444,7 @@ namespace Org.BouncyCastle.Tls
                 }
             }
 
-            state.clientExtensions = TlsExtensionsUtilities.EnsureExtensionsInitialised(
-                state.client.GetClientExtensions());
+            state.clientExtensions = TlsExtensionsUtilities.EnsureExtensionsInitialised(client.GetClientExtensions());
 
             ProtocolVersion legacy_version = client_version;
             if (client_version.IsLaterVersionOf(ProtocolVersion.DTLSv12))
@@ -449,10 +452,10 @@ namespace Org.BouncyCastle.Tls
                 legacy_version = ProtocolVersion.DTLSv12;
 
                 TlsExtensionsUtilities.AddSupportedVersionsExtensionClient(state.clientExtensions,
-                    context.ClientSupportedVersions);
+                    clientContext.ClientSupportedVersions);
             }
 
-            context.SetRsaPreMasterSecretVersion(legacy_version);
+            clientContext.SetRsaPreMasterSecretVersion(legacy_version);
 
             securityParameters.m_clientServerNames = TlsExtensionsUtilities.GetServerNameExtensionClient(
                 state.clientExtensions);
@@ -465,16 +468,16 @@ namespace Org.BouncyCastle.Tls
             securityParameters.m_clientSupportedGroups = TlsExtensionsUtilities.GetSupportedGroupsExtension(
                 state.clientExtensions);
 
-            state.clientAgreements = TlsUtilities.AddKeyShareToClientHello(state.clientContext, state.client,
+            state.clientAgreements = TlsUtilities.AddKeyShareToClientHello(clientContext, client,
                 state.clientExtensions);
 
-            if (TlsUtilities.IsExtendedMasterSecretOptional(context.ClientSupportedVersions)
-                && state.client.ShouldUseExtendedMasterSecret())
+            if (TlsUtilities.IsExtendedMasterSecretOptional(clientContext.ClientSupportedVersions)
+                && client.ShouldUseExtendedMasterSecret())
             {
                 TlsExtensionsUtilities.AddExtendedMasterSecretExtension(state.clientExtensions);
             }
             else if (!TlsUtilities.IsTlsV13(client_version)
-                && state.client.RequiresExtendedMasterSecret())
+                && client.RequiresExtendedMasterSecret())
             {
                 throw new TlsFatalAlert(AlertDescription.internal_error);
             }
@@ -512,8 +515,8 @@ namespace Org.BouncyCastle.Tls
 
             // Heartbeats
             {
-                state.heartbeat = state.client.GetHeartbeat();
-                state.heartbeatPolicy = state.client.GetHeartbeatPolicy();
+                state.heartbeat = client.GetHeartbeat();
+                state.heartbeatPolicy = client.GetHeartbeatPolicy();
 
                 if (null != state.heartbeat || HeartbeatMode.peer_allowed_to_send == state.heartbeatPolicy)
                 {
@@ -528,7 +531,7 @@ namespace Org.BouncyCastle.Tls
                 cookie: TlsUtilities.EmptyBytes, state.offeredCipherSuites, state.clientExtensions, 0);
 
             MemoryStream buf = new MemoryStream();
-            clientHello.Encode(state.clientContext, buf);
+            clientHello.Encode(clientContext, buf);
             return buf.ToArray();
         }
 
@@ -573,17 +576,19 @@ namespace Org.BouncyCastle.Tls
                 throw new TlsFatalAlert(AlertDescription.handshake_failure);
             }
 
+            TlsClientContextImpl clientContext = state.clientContext;
+            SecurityParameters securityParameters = clientContext.SecurityParameters;
+
             MemoryStream buf = new MemoryStream(body, false);
 
-            CertificateRequest certificateRequest = CertificateRequest.Parse(state.clientContext, buf);
+            CertificateRequest certificateRequest = CertificateRequest.Parse(clientContext, buf);
 
             TlsProtocol.AssertEmpty(buf);
 
             state.certificateRequest = TlsUtilities.ValidateCertificateRequest(certificateRequest, state.keyExchange);
 
-            state.clientContext.SecurityParameters.m_clientCertificateType =
-                TlsExtensionsUtilities.GetClientCertificateTypeExtensionServer(state.serverExtensions,
-                    CertificateType.X509);
+            securityParameters.m_clientCertificateType = TlsExtensionsUtilities.GetClientCertificateTypeExtensionServer(
+                state.serverExtensions, CertificateType.X509);
         }
 
         /// <exception cref="IOException"/>
@@ -644,6 +649,11 @@ namespace Org.BouncyCastle.Tls
         /// <exception cref="IOException"/>
         protected virtual void ProcessServerHello(ClientHandshakeState state, byte[] body)
         {
+            TlsClient client = state.client;
+            TlsClientContextImpl clientContext = state.clientContext;
+            SecurityParameters securityParameters = clientContext.SecurityParameters;
+
+
             MemoryStream buf = new MemoryStream(body, false);
 
             ServerHello serverHello = ServerHello.Parse(buf);
@@ -652,16 +662,13 @@ namespace Org.BouncyCastle.Tls
             state.serverExtensions = serverHello.Extensions;
 
 
-
-            SecurityParameters securityParameters = state.clientContext.SecurityParameters;
-
             // TODO[dtls13] Check supported_version extension for negotiated version
 
             ReportServerVersion(state, server_version);
 
             securityParameters.m_serverRandom = serverHello.Random;
 
-            if (!state.clientContext.ClientVersion.Equals(server_version))
+            if (!clientContext.ClientVersion.Equals(server_version))
             {
                 TlsUtilities.CheckDowngradeMarker(server_version, securityParameters.ServerRandom);
             }
@@ -669,7 +676,7 @@ namespace Org.BouncyCastle.Tls
             {
                 byte[] selectedSessionID = serverHello.SessionID;
                 securityParameters.m_sessionID = selectedSessionID;
-                state.client.NotifySessionID(selectedSessionID);
+                client.NotifySessionID(selectedSessionID);
                 securityParameters.m_resumedSession = selectedSessionID.Length > 0 && state.tlsSession != null
                     && Arrays.AreEqual(selectedSessionID, state.tlsSession.SessionID);
             }
@@ -689,7 +696,7 @@ namespace Org.BouncyCastle.Tls
                 }
 
                 TlsUtilities.NegotiatedCipherSuite(securityParameters, cipherSuite);
-                state.client.NotifySelectedCipherSuite(cipherSuite);
+                client.NotifySelectedCipherSuite(cipherSuite);
             }
 
             /*
@@ -726,13 +733,13 @@ namespace Org.BouncyCastle.Tls
 
                 if (acceptedExtendedMasterSecret)
                 {
-                    if (!securityParameters.IsResumedSession && !state.client.ShouldUseExtendedMasterSecret())
+                    if (!securityParameters.IsResumedSession && !client.ShouldUseExtendedMasterSecret())
                         throw new TlsFatalAlert(AlertDescription.handshake_failure);
                 }
                 else
                 {
-                    if (state.client.RequiresExtendedMasterSecret()
-                        || (securityParameters.IsResumedSession && !state.client.AllowLegacyResumption()))
+                    if (client.RequiresExtendedMasterSecret()
+                        || (securityParameters.IsResumedSession && !client.AllowLegacyResumption()))
                     {
                         throw new TlsFatalAlert(AlertDescription.handshake_failure);
                     }
@@ -815,7 +822,7 @@ namespace Org.BouncyCastle.Tls
             }
 
             // TODO[compat-gnutls] GnuTLS test server fails to send renegotiation_info extension when resuming
-            state.client.NotifySecureRenegotiation(securityParameters.IsSecureRenegotiation);
+            client.NotifySecureRenegotiation(securityParameters.IsSecureRenegotiation);
 
             /*
              * RFC 7301 3.1. When session resumption or session tickets [...] are used, the previous
@@ -920,7 +927,7 @@ namespace Org.BouncyCastle.Tls
 
             if (sessionClientExtensions != null)
             {
-                state.client.ProcessServerExtensions(sessionServerExtensions);
+                client.ProcessServerExtensions(sessionServerExtensions);
             }
         }
 
@@ -943,8 +950,8 @@ namespace Org.BouncyCastle.Tls
         /// <exception cref="IOException"/>
         protected virtual void ReportServerVersion(ClientHandshakeState state, ProtocolVersion server_version)
         {
-            TlsClientContextImpl context = state.clientContext;
-            SecurityParameters securityParameters = context.SecurityParameters;
+            TlsClientContextImpl clientContext = state.clientContext;
+            SecurityParameters securityParameters = clientContext.SecurityParameters;
 
             ProtocolVersion currentServerVersion = securityParameters.NegotiatedVersion;
             if (null != currentServerVersion)
@@ -955,12 +962,12 @@ namespace Org.BouncyCastle.Tls
                 return;
             }
 
-            if (!ProtocolVersion.Contains(context.ClientSupportedVersions, server_version))
+            if (!ProtocolVersion.Contains(clientContext.ClientSupportedVersions, server_version))
                 throw new TlsFatalAlert(AlertDescription.protocol_version);
 
             securityParameters.m_negotiatedVersion = server_version;
 
-            TlsUtilities.NegotiatedVersionDtlsClient(state.clientContext, state.client);
+            TlsUtilities.NegotiatedVersionDtlsClient(clientContext, state.client);
         }
 
         /// <exception cref="IOException"/>