summary refs log tree commit diff
path: root/crypto/src/tls/DtlsServerProtocol.cs
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2023-07-07 09:00:29 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2023-07-07 09:00:29 +0700
commit8c5c106ef41adcfec646ba23762f59d6e8861e20 (patch)
treecff154b3a8298d9e3f928c2cc0c94f5cbcfc9b92 /crypto/src/tls/DtlsServerProtocol.cs
parent(D)TLS: Refactoring around the MFL extension (diff)
downloadBouncyCastle.NET-ed25519-8c5c106ef41adcfec646ba23762f59d6e8861e20.tar.xz
Refactoring in DTLS
Diffstat (limited to 'crypto/src/tls/DtlsServerProtocol.cs')
-rw-r--r--crypto/src/tls/DtlsServerProtocol.cs133
1 files changed, 70 insertions, 63 deletions
diff --git a/crypto/src/tls/DtlsServerProtocol.cs b/crypto/src/tls/DtlsServerProtocol.cs
index 66ab6d294..e3f2d7564 100644
--- a/crypto/src/tls/DtlsServerProtocol.cs
+++ b/crypto/src/tls/DtlsServerProtocol.cs
@@ -38,16 +38,19 @@ namespace Org.BouncyCastle.Tls
             if (transport == null)
                 throw new ArgumentNullException("transport");
 
+            TlsServerContextImpl serverContext = new TlsServerContextImpl(server.Crypto);
+
             ServerHandshakeState state = new ServerHandshakeState();
             state.server = server;
-            state.serverContext = new TlsServerContextImpl(server.Crypto);
-            server.Init(state.serverContext);
-            state.serverContext.HandshakeBeginning(server);
+            state.serverContext = serverContext;
+
+            server.Init(serverContext);
+            serverContext.HandshakeBeginning(server);
 
-            SecurityParameters securityParameters = state.serverContext.SecurityParameters;
+            SecurityParameters securityParameters = serverContext.SecurityParameters;
             securityParameters.m_extendedPadding = server.ShouldUseExtendedPadding();
 
-            DtlsRecordLayer recordLayer = new DtlsRecordLayer(state.serverContext, state.server, transport);
+            DtlsRecordLayer recordLayer = new DtlsRecordLayer(serverContext, server, transport);
             server.NotifyCloseHandle(recordLayer);
 
             try
@@ -86,11 +89,12 @@ namespace Org.BouncyCastle.Tls
         internal virtual DtlsTransport ServerHandshake(ServerHandshakeState state, DtlsRecordLayer recordLayer,
             DtlsRequest request)
         {
-            SecurityParameters securityParameters = state.serverContext.SecurityParameters;
+            TlsServer server = state.server;
+            TlsServerContextImpl serverContext = state.serverContext;
+            SecurityParameters securityParameters = serverContext.SecurityParameters;
 
-            DtlsReliableHandshake handshake = new DtlsReliableHandshake(state.serverContext, recordLayer,
-                state.server.GetHandshakeTimeoutMillis(), TlsUtilities.GetHandshakeResendTimeMillis(state.server),
-                request);
+            DtlsReliableHandshake handshake = new DtlsReliableHandshake(serverContext, recordLayer,
+                server.GetHandshakeTimeoutMillis(), TlsUtilities.GetHandshakeResendTimeMillis(server), request);
 
             DtlsReliableHandshake.Message clientMessage = null;
 
@@ -132,14 +136,14 @@ namespace Org.BouncyCastle.Tls
             securityParameters.m_resumedSession = false;
             securityParameters.m_sessionID = state.tlsSession.SessionID;
 
-            state.server.NotifySession(state.tlsSession);
+            server.NotifySession(state.tlsSession);
 
             {
                 byte[] serverHelloBody = GenerateServerHello(state, recordLayer);
 
                 // TODO[dtls13] Ideally, move this into GenerateServerHello once legacy_record_version clarified
                 {
-                    ProtocolVersion recordLayerVersion = state.serverContext.ServerVersion;
+                    ProtocolVersion recordLayerVersion = serverContext.ServerVersion;
                     recordLayer.ReadVersion = recordLayerVersion;
                     recordLayer.SetWriteVersion(recordLayerVersion);
                 }
@@ -149,20 +153,20 @@ namespace Org.BouncyCastle.Tls
 
             handshake.HandshakeHash.NotifyPrfDetermined();
 
-            var serverSupplementalData = state.server.GetServerSupplementalData();
+            var serverSupplementalData = server.GetServerSupplementalData();
             if (serverSupplementalData != null)
             {
                 byte[] supplementalDataBody = GenerateSupplementalData(serverSupplementalData);
                 handshake.SendMessage(HandshakeType.supplemental_data, supplementalDataBody);
             }
 
-            state.keyExchange = TlsUtilities.InitKeyExchangeServer(state.serverContext, state.server);
+            state.keyExchange = TlsUtilities.InitKeyExchangeServer(serverContext, server);
 
             state.serverCredentials = null;
 
             if (!KeyExchangeAlgorithm.IsAnonymous(securityParameters.KeyExchangeAlgorithm))
             {
-                state.serverCredentials = TlsUtilities.EstablishServerCredentials(state.server);
+                state.serverCredentials = TlsUtilities.EstablishServerCredentials(server);
             }
 
             // Server certificate
@@ -180,7 +184,7 @@ namespace Org.BouncyCastle.Tls
 
                     serverCertificate = state.serverCredentials.Certificate;
 
-                    SendCertificateMessage(state.serverContext, handshake, serverCertificate, endPointHash);
+                    SendCertificateMessage(serverContext, handshake, serverCertificate, endPointHash);
                 }
                 securityParameters.m_tlsServerEndPoint = endPointHash.ToArray();
 
@@ -193,7 +197,7 @@ namespace Org.BouncyCastle.Tls
 
             if (securityParameters.StatusRequestVersion > 0)
             {
-                CertificateStatus certificateStatus = state.server.GetCertificateStatus();
+                CertificateStatus certificateStatus = server.GetCertificateStatus();
                 if (certificateStatus != null)
                 {
                     byte[] certificateStatusBody = GenerateCertificateStatus(state, certificateStatus);
@@ -209,7 +213,7 @@ namespace Org.BouncyCastle.Tls
 
             if (state.serverCredentials != null)
             {
-                state.certificateRequest = state.server.GetCertificateRequest();
+                state.certificateRequest = server.GetCertificateRequest();
 
                 if (null == state.certificateRequest)
                 {
@@ -223,7 +227,7 @@ namespace Org.BouncyCastle.Tls
                 }
                 else
                 {
-                    if (TlsUtilities.IsTlsV12(state.serverContext)
+                    if (TlsUtilities.IsTlsV12(serverContext)
                         != (state.certificateRequest.SupportedSignatureAlgorithms != null))
                     {
                         throw new TlsFatalAlert(AlertDescription.internal_error);
@@ -237,14 +241,14 @@ namespace Org.BouncyCastle.Tls
                     {
                         TlsUtilities.TrackHashAlgorithms(handshake.HandshakeHash, securityParameters.ServerSigAlgs);
 
-                        if (state.serverContext.Crypto.HasAnyStreamVerifiers(securityParameters.ServerSigAlgs))
+                        if (serverContext.Crypto.HasAnyStreamVerifiers(securityParameters.ServerSigAlgs))
                         {
                             handshake.HandshakeHash.ForceBuffering();
                         }
                     }
                     else
                     {
-                        if (state.serverContext.Crypto.HasAnyStreamVerifiersLegacy(state.certificateRequest.CertificateTypes))
+                        if (serverContext.Crypto.HasAnyStreamVerifiersLegacy(state.certificateRequest.CertificateTypes))
                         {
                             handshake.HandshakeHash.ForceBuffering();
                         }
@@ -271,7 +275,7 @@ namespace Org.BouncyCastle.Tls
             }
             else
             {
-                state.server.ProcessClientSupplementalData(null);
+                server.ProcessClientSupplementalData(null);
             }
 
             if (state.certificateRequest == null)
@@ -287,7 +291,7 @@ namespace Org.BouncyCastle.Tls
                 }
                 else
                 {
-                    if (TlsUtilities.IsTlsV12(state.serverContext))
+                    if (TlsUtilities.IsTlsV12(serverContext))
                     {
                         /*
                          * RFC 5246 If no suitable certificate is available, the client MUST send a
@@ -313,8 +317,8 @@ namespace Org.BouncyCastle.Tls
 
             securityParameters.m_sessionHash = TlsUtilities.GetCurrentPrfHash(handshake.HandshakeHash);
 
-            TlsProtocol.EstablishMasterSecret(state.serverContext, state.keyExchange);
-            recordLayer.InitPendingEpoch(TlsUtilities.InitCipher(state.serverContext));
+            TlsProtocol.EstablishMasterSecret(serverContext, state.keyExchange);
+            recordLayer.InitPendingEpoch(TlsUtilities.InitCipher(serverContext));
 
             /*
              * RFC 5246 7.4.8 This message is only sent following a client certificate that has signing
@@ -337,7 +341,7 @@ namespace Org.BouncyCastle.Tls
             }
 
             // NOTE: Calculated exclusive of the actual Finished message from the client
-            securityParameters.m_peerVerifyData = TlsUtilities.CalculateVerifyData(state.serverContext,
+            securityParameters.m_peerVerifyData = TlsUtilities.CalculateVerifyData(serverContext,
                 handshake.HandshakeHash, false);
             ProcessFinished(handshake.ReceiveMessageBody(HandshakeType.finished), securityParameters.PeerVerifyData);
 
@@ -348,13 +352,13 @@ namespace Org.BouncyCastle.Tls
                 * is going to ignore any session ID it received once it sees the new_session_ticket message.
                 */
 
-                NewSessionTicket newSessionTicket = state.server.GetNewSessionTicket();
+                NewSessionTicket newSessionTicket = server.GetNewSessionTicket();
                 byte[] newSessionTicketBody = GenerateNewSessionTicket(state, newSessionTicket);
                 handshake.SendMessage(HandshakeType.new_session_ticket, newSessionTicketBody);
             }
 
             // NOTE: Calculated exclusive of the Finished message itself
-            securityParameters.m_localVerifyData = TlsUtilities.CalculateVerifyData(state.serverContext,
+            securityParameters.m_localVerifyData = TlsUtilities.CalculateVerifyData(serverContext,
                 handshake.HandshakeHash, true);
             handshake.SendMessage(HandshakeType.finished, securityParameters.LocalVerifyData);
 
@@ -366,7 +370,7 @@ namespace Org.BouncyCastle.Tls
                 .SetCipherSuite(securityParameters.CipherSuite)
                 .SetExtendedMasterSecret(securityParameters.IsExtendedMasterSecret)
                 .SetLocalCertificate(securityParameters.LocalCertificate)
-                .SetMasterSecret(state.serverContext.Crypto.AdoptSecret(state.sessionMasterSecret))
+                .SetMasterSecret(serverContext.Crypto.AdoptSecret(state.sessionMasterSecret))
                 .SetNegotiatedVersion(securityParameters.NegotiatedVersion)
                 .SetPeerCertificate(securityParameters.PeerCertificate)
                 .SetPskIdentity(securityParameters.PskIdentity)
@@ -379,11 +383,11 @@ namespace Org.BouncyCastle.Tls
 
             securityParameters.m_tlsUnique = securityParameters.PeerVerifyData;
 
-            state.serverContext.HandshakeComplete(state.server, state.tlsSession);
+            serverContext.HandshakeComplete(server, state.tlsSession);
 
             recordLayer.InitHeartbeat(state.heartbeat, HeartbeatMode.peer_allowed_to_send == state.heartbeatPolicy);
 
-            return new DtlsTransport(recordLayer, state.server.IgnoreCorruptDtlsRecords);
+            return new DtlsTransport(recordLayer, server.IgnoreCorruptDtlsRecords);
         }
 
         /// <exception cref="IOException"/>
@@ -417,12 +421,13 @@ namespace Org.BouncyCastle.Tls
         /// <exception cref="IOException"/>
         internal virtual byte[] GenerateServerHello(ServerHandshakeState state, DtlsRecordLayer recordLayer)
         {
-            TlsServerContextImpl context = state.serverContext;
-            SecurityParameters securityParameters = context.SecurityParameters;
+            TlsServer server = state.server;
+            TlsServerContextImpl serverContext = state.serverContext;
+            SecurityParameters securityParameters = serverContext.SecurityParameters;
 
-            ProtocolVersion server_version = state.server.GetServerVersion();
+            ProtocolVersion server_version = server.GetServerVersion();
             {
-                if (!ProtocolVersion.Contains(context.ClientSupportedVersions, server_version))
+                if (!ProtocolVersion.Contains(serverContext.ClientSupportedVersions, server_version))
                     throw new TlsFatalAlert(AlertDescription.internal_error);
 
                 // TODO[dtls13] Read draft/RFC for guidance on the legacy_record_version field
@@ -433,16 +438,16 @@ namespace Org.BouncyCastle.Tls
                 //recordLayer.SetWriteVersion(legacy_record_version);
                 securityParameters.m_negotiatedVersion = server_version;
 
-                TlsUtilities.NegotiatedVersionDtlsServer(context);
+                TlsUtilities.NegotiatedVersionDtlsServer(serverContext);
             }
 
             {
                 bool useGmtUnixTime = ProtocolVersion.DTLSv12.IsEqualOrLaterVersionOf(server_version)
-                    && state.server.ShouldUseGmtUnixTime();
+                    && server.ShouldUseGmtUnixTime();
 
-                securityParameters.m_serverRandom = TlsProtocol.CreateRandomBlock(useGmtUnixTime, context);
+                securityParameters.m_serverRandom = TlsProtocol.CreateRandomBlock(useGmtUnixTime, serverContext);
 
-                if (!server_version.Equals(ProtocolVersion.GetLatestDtls(state.server.GetProtocolVersions())))
+                if (!server_version.Equals(ProtocolVersion.GetLatestDtls(server.GetProtocolVersions())))
                 {
                     TlsUtilities.WriteDowngradeMarker(server_version, securityParameters.ServerRandom);
                 }
@@ -451,7 +456,7 @@ namespace Org.BouncyCastle.Tls
             bool resumedSession = securityParameters.IsResumedSession;
 
             {
-                int cipherSuite = ValidateSelectedCipherSuite(state.server.GetSelectedCipherSuite(),
+                int cipherSuite = ValidateSelectedCipherSuite(server.GetSelectedCipherSuite(),
                     AlertDescription.internal_error);
 
                 if (!TlsUtilities.IsValidCipherSuiteSelection(state.offeredCipherSuites, cipherSuite) ||
@@ -464,9 +469,9 @@ namespace Org.BouncyCastle.Tls
             }
 
             state.serverExtensions = TlsExtensionsUtilities.EnsureExtensionsInitialised(
-                state.server.GetServerExtensions());
+                server.GetServerExtensions());
 
-            state.server.GetServerExtensionsForConnection(state.serverExtensions);
+            server.GetServerExtensionsForConnection(state.serverExtensions);
 
             ProtocolVersion legacy_version = server_version;
             if (server_version.IsLaterVersionOf(ProtocolVersion.DTLSv12))
@@ -519,17 +524,17 @@ namespace Org.BouncyCastle.Tls
             else
             {
                 securityParameters.m_extendedMasterSecret = state.offeredExtendedMasterSecret
-                    && state.server.ShouldUseExtendedMasterSecret();
+                    && server.ShouldUseExtendedMasterSecret();
 
                 if (securityParameters.IsExtendedMasterSecret)
                 {
                     TlsExtensionsUtilities.AddExtendedMasterSecretExtension(state.serverExtensions);
                 }
-                else if (state.server.RequiresExtendedMasterSecret())
+                else if (server.RequiresExtendedMasterSecret())
                 {
                     throw new TlsFatalAlert(AlertDescription.handshake_failure);
                 }
-                else if (resumedSession && !state.server.AllowLegacyResumption())
+                else if (resumedSession && !server.AllowLegacyResumption())
                 {
                     throw new TlsFatalAlert(AlertDescription.internal_error);
                 }
@@ -619,7 +624,7 @@ namespace Org.BouncyCastle.Tls
                 state.tlsSession.SessionID, securityParameters.CipherSuite, state.serverExtensions);
 
             MemoryStream buf = new MemoryStream();
-            serverHello.Encode(state.serverContext, buf);
+            serverHello.Encode(serverContext, buf);
             return buf.ToArray();
         }
 
@@ -682,12 +687,13 @@ namespace Org.BouncyCastle.Tls
 
             MemoryStream buf = new MemoryStream(body, false);
 
-            TlsServerContextImpl context = state.serverContext;
-            DigitallySigned certificateVerify = DigitallySigned.Parse(context, buf);
+            TlsServerContextImpl serverContext = state.serverContext;
+            DigitallySigned certificateVerify = DigitallySigned.Parse(serverContext, buf);
 
             TlsProtocol.AssertEmpty(buf);
 
-            TlsUtilities.VerifyCertificateVerifyClient(context, state.certificateRequest, certificateVerify, handshakeHash);
+            TlsUtilities.VerifyCertificateVerifyClient(serverContext, state.certificateRequest, certificateVerify,
+                handshakeHash);
         }
 
         /// <exception cref="IOException"/>
@@ -714,44 +720,45 @@ namespace Org.BouncyCastle.Tls
 
 
 
-            TlsServerContextImpl context = state.serverContext;
-            SecurityParameters securityParameters = context.SecurityParameters;
+            TlsServer server = state.server;
+            TlsServerContextImpl serverContext = state.serverContext;
+            SecurityParameters securityParameters = serverContext.SecurityParameters;
 
             if (!legacy_version.IsDtls)
                 throw new TlsFatalAlert(AlertDescription.illegal_parameter);
 
-            context.SetRsaPreMasterSecretVersion(legacy_version);
+            serverContext.SetRsaPreMasterSecretVersion(legacy_version);
 
-            context.SetClientSupportedVersions(
+            serverContext.SetClientSupportedVersions(
                 TlsExtensionsUtilities.GetSupportedVersionsExtensionClient(state.clientExtensions));
 
             ProtocolVersion client_version = legacy_version;
-            if (null == context.ClientSupportedVersions)
+            if (null == serverContext.ClientSupportedVersions)
             {
                 if (client_version.IsLaterVersionOf(ProtocolVersion.DTLSv12))
                 {
                     client_version = ProtocolVersion.DTLSv12;
                 }
 
-                context.SetClientSupportedVersions(client_version.DownTo(ProtocolVersion.DTLSv10));
+                serverContext.SetClientSupportedVersions(client_version.DownTo(ProtocolVersion.DTLSv10));
             }
             else
             {
-                client_version = ProtocolVersion.GetLatestDtls(context.ClientSupportedVersions);
+                client_version = ProtocolVersion.GetLatestDtls(serverContext.ClientSupportedVersions);
             }
 
             if (!ProtocolVersion.SERVER_EARLIEST_SUPPORTED_DTLS.IsEqualOrEarlierVersionOf(client_version))
                 throw new TlsFatalAlert(AlertDescription.protocol_version);
 
-            context.SetClientVersion(client_version);
+            serverContext.SetClientVersion(client_version);
 
-            state.server.NotifyClientVersion(context.ClientVersion);
+            server.NotifyClientVersion(serverContext.ClientVersion);
 
             securityParameters.m_clientRandom = clientHello.Random;
 
-            state.server.NotifyFallback(Arrays.Contains(state.offeredCipherSuites, CipherSuite.TLS_FALLBACK_SCSV));
+            server.NotifyFallback(Arrays.Contains(state.offeredCipherSuites, CipherSuite.TLS_FALLBACK_SCSV));
 
-            state.server.NotifyOfferedCipherSuites(state.offeredCipherSuites);
+            server.NotifyOfferedCipherSuites(state.offeredCipherSuites);
 
             /*
              * TODO[resumption] Check RFC 7627 5.4. for required behaviour 
@@ -800,7 +807,7 @@ namespace Org.BouncyCastle.Tls
                 }
             }
 
-            state.server.NotifySecureRenegotiation(securityParameters.IsSecureRenegotiation);
+            server.NotifySecureRenegotiation(securityParameters.IsSecureRenegotiation);
 
             state.offeredExtendedMasterSecret = TlsExtensionsUtilities.HasExtendedMasterSecretExtension(
                 state.clientExtensions);
@@ -833,14 +840,14 @@ namespace Org.BouncyCastle.Tls
                     {
                         if (HeartbeatMode.peer_allowed_to_send == heartbeatExtension.Mode)
                         {
-                            state.heartbeat = state.server.GetHeartbeat();
+                            state.heartbeat = server.GetHeartbeat();
                         }
 
-                        state.heartbeatPolicy = state.server.GetHeartbeatPolicy();
+                        state.heartbeatPolicy = server.GetHeartbeatPolicy();
                     }
                 }
 
-                state.server.ProcessClientExtensions(state.clientExtensions);
+                server.ProcessClientExtensions(state.clientExtensions);
             }
         }