diff options
author | Peter Dettman <peter.dettman@bouncycastle.org> | 2023-07-07 09:00:29 +0700 |
---|---|---|
committer | Peter Dettman <peter.dettman@bouncycastle.org> | 2023-07-07 09:00:29 +0700 |
commit | 8c5c106ef41adcfec646ba23762f59d6e8861e20 (patch) | |
tree | cff154b3a8298d9e3f928c2cc0c94f5cbcfc9b92 /crypto/src/tls/DtlsServerProtocol.cs | |
parent | (D)TLS: Refactoring around the MFL extension (diff) | |
download | BouncyCastle.NET-ed25519-8c5c106ef41adcfec646ba23762f59d6e8861e20.tar.xz |
Refactoring in DTLS
Diffstat (limited to 'crypto/src/tls/DtlsServerProtocol.cs')
-rw-r--r-- | crypto/src/tls/DtlsServerProtocol.cs | 133 |
1 files changed, 70 insertions, 63 deletions
diff --git a/crypto/src/tls/DtlsServerProtocol.cs b/crypto/src/tls/DtlsServerProtocol.cs index 66ab6d294..e3f2d7564 100644 --- a/crypto/src/tls/DtlsServerProtocol.cs +++ b/crypto/src/tls/DtlsServerProtocol.cs @@ -38,16 +38,19 @@ namespace Org.BouncyCastle.Tls if (transport == null) throw new ArgumentNullException("transport"); + TlsServerContextImpl serverContext = new TlsServerContextImpl(server.Crypto); + ServerHandshakeState state = new ServerHandshakeState(); state.server = server; - state.serverContext = new TlsServerContextImpl(server.Crypto); - server.Init(state.serverContext); - state.serverContext.HandshakeBeginning(server); + state.serverContext = serverContext; + + server.Init(serverContext); + serverContext.HandshakeBeginning(server); - SecurityParameters securityParameters = state.serverContext.SecurityParameters; + SecurityParameters securityParameters = serverContext.SecurityParameters; securityParameters.m_extendedPadding = server.ShouldUseExtendedPadding(); - DtlsRecordLayer recordLayer = new DtlsRecordLayer(state.serverContext, state.server, transport); + DtlsRecordLayer recordLayer = new DtlsRecordLayer(serverContext, server, transport); server.NotifyCloseHandle(recordLayer); try @@ -86,11 +89,12 @@ namespace Org.BouncyCastle.Tls internal virtual DtlsTransport ServerHandshake(ServerHandshakeState state, DtlsRecordLayer recordLayer, DtlsRequest request) { - SecurityParameters securityParameters = state.serverContext.SecurityParameters; + TlsServer server = state.server; + TlsServerContextImpl serverContext = state.serverContext; + SecurityParameters securityParameters = serverContext.SecurityParameters; - DtlsReliableHandshake handshake = new DtlsReliableHandshake(state.serverContext, recordLayer, - state.server.GetHandshakeTimeoutMillis(), TlsUtilities.GetHandshakeResendTimeMillis(state.server), - request); + DtlsReliableHandshake handshake = new DtlsReliableHandshake(serverContext, recordLayer, + server.GetHandshakeTimeoutMillis(), TlsUtilities.GetHandshakeResendTimeMillis(server), request); DtlsReliableHandshake.Message clientMessage = null; @@ -132,14 +136,14 @@ namespace Org.BouncyCastle.Tls securityParameters.m_resumedSession = false; securityParameters.m_sessionID = state.tlsSession.SessionID; - state.server.NotifySession(state.tlsSession); + server.NotifySession(state.tlsSession); { byte[] serverHelloBody = GenerateServerHello(state, recordLayer); // TODO[dtls13] Ideally, move this into GenerateServerHello once legacy_record_version clarified { - ProtocolVersion recordLayerVersion = state.serverContext.ServerVersion; + ProtocolVersion recordLayerVersion = serverContext.ServerVersion; recordLayer.ReadVersion = recordLayerVersion; recordLayer.SetWriteVersion(recordLayerVersion); } @@ -149,20 +153,20 @@ namespace Org.BouncyCastle.Tls handshake.HandshakeHash.NotifyPrfDetermined(); - var serverSupplementalData = state.server.GetServerSupplementalData(); + var serverSupplementalData = server.GetServerSupplementalData(); if (serverSupplementalData != null) { byte[] supplementalDataBody = GenerateSupplementalData(serverSupplementalData); handshake.SendMessage(HandshakeType.supplemental_data, supplementalDataBody); } - state.keyExchange = TlsUtilities.InitKeyExchangeServer(state.serverContext, state.server); + state.keyExchange = TlsUtilities.InitKeyExchangeServer(serverContext, server); state.serverCredentials = null; if (!KeyExchangeAlgorithm.IsAnonymous(securityParameters.KeyExchangeAlgorithm)) { - state.serverCredentials = TlsUtilities.EstablishServerCredentials(state.server); + state.serverCredentials = TlsUtilities.EstablishServerCredentials(server); } // Server certificate @@ -180,7 +184,7 @@ namespace Org.BouncyCastle.Tls serverCertificate = state.serverCredentials.Certificate; - SendCertificateMessage(state.serverContext, handshake, serverCertificate, endPointHash); + SendCertificateMessage(serverContext, handshake, serverCertificate, endPointHash); } securityParameters.m_tlsServerEndPoint = endPointHash.ToArray(); @@ -193,7 +197,7 @@ namespace Org.BouncyCastle.Tls if (securityParameters.StatusRequestVersion > 0) { - CertificateStatus certificateStatus = state.server.GetCertificateStatus(); + CertificateStatus certificateStatus = server.GetCertificateStatus(); if (certificateStatus != null) { byte[] certificateStatusBody = GenerateCertificateStatus(state, certificateStatus); @@ -209,7 +213,7 @@ namespace Org.BouncyCastle.Tls if (state.serverCredentials != null) { - state.certificateRequest = state.server.GetCertificateRequest(); + state.certificateRequest = server.GetCertificateRequest(); if (null == state.certificateRequest) { @@ -223,7 +227,7 @@ namespace Org.BouncyCastle.Tls } else { - if (TlsUtilities.IsTlsV12(state.serverContext) + if (TlsUtilities.IsTlsV12(serverContext) != (state.certificateRequest.SupportedSignatureAlgorithms != null)) { throw new TlsFatalAlert(AlertDescription.internal_error); @@ -237,14 +241,14 @@ namespace Org.BouncyCastle.Tls { TlsUtilities.TrackHashAlgorithms(handshake.HandshakeHash, securityParameters.ServerSigAlgs); - if (state.serverContext.Crypto.HasAnyStreamVerifiers(securityParameters.ServerSigAlgs)) + if (serverContext.Crypto.HasAnyStreamVerifiers(securityParameters.ServerSigAlgs)) { handshake.HandshakeHash.ForceBuffering(); } } else { - if (state.serverContext.Crypto.HasAnyStreamVerifiersLegacy(state.certificateRequest.CertificateTypes)) + if (serverContext.Crypto.HasAnyStreamVerifiersLegacy(state.certificateRequest.CertificateTypes)) { handshake.HandshakeHash.ForceBuffering(); } @@ -271,7 +275,7 @@ namespace Org.BouncyCastle.Tls } else { - state.server.ProcessClientSupplementalData(null); + server.ProcessClientSupplementalData(null); } if (state.certificateRequest == null) @@ -287,7 +291,7 @@ namespace Org.BouncyCastle.Tls } else { - if (TlsUtilities.IsTlsV12(state.serverContext)) + if (TlsUtilities.IsTlsV12(serverContext)) { /* * RFC 5246 If no suitable certificate is available, the client MUST send a @@ -313,8 +317,8 @@ namespace Org.BouncyCastle.Tls securityParameters.m_sessionHash = TlsUtilities.GetCurrentPrfHash(handshake.HandshakeHash); - TlsProtocol.EstablishMasterSecret(state.serverContext, state.keyExchange); - recordLayer.InitPendingEpoch(TlsUtilities.InitCipher(state.serverContext)); + TlsProtocol.EstablishMasterSecret(serverContext, state.keyExchange); + recordLayer.InitPendingEpoch(TlsUtilities.InitCipher(serverContext)); /* * RFC 5246 7.4.8 This message is only sent following a client certificate that has signing @@ -337,7 +341,7 @@ namespace Org.BouncyCastle.Tls } // NOTE: Calculated exclusive of the actual Finished message from the client - securityParameters.m_peerVerifyData = TlsUtilities.CalculateVerifyData(state.serverContext, + securityParameters.m_peerVerifyData = TlsUtilities.CalculateVerifyData(serverContext, handshake.HandshakeHash, false); ProcessFinished(handshake.ReceiveMessageBody(HandshakeType.finished), securityParameters.PeerVerifyData); @@ -348,13 +352,13 @@ namespace Org.BouncyCastle.Tls * is going to ignore any session ID it received once it sees the new_session_ticket message. */ - NewSessionTicket newSessionTicket = state.server.GetNewSessionTicket(); + NewSessionTicket newSessionTicket = server.GetNewSessionTicket(); byte[] newSessionTicketBody = GenerateNewSessionTicket(state, newSessionTicket); handshake.SendMessage(HandshakeType.new_session_ticket, newSessionTicketBody); } // NOTE: Calculated exclusive of the Finished message itself - securityParameters.m_localVerifyData = TlsUtilities.CalculateVerifyData(state.serverContext, + securityParameters.m_localVerifyData = TlsUtilities.CalculateVerifyData(serverContext, handshake.HandshakeHash, true); handshake.SendMessage(HandshakeType.finished, securityParameters.LocalVerifyData); @@ -366,7 +370,7 @@ namespace Org.BouncyCastle.Tls .SetCipherSuite(securityParameters.CipherSuite) .SetExtendedMasterSecret(securityParameters.IsExtendedMasterSecret) .SetLocalCertificate(securityParameters.LocalCertificate) - .SetMasterSecret(state.serverContext.Crypto.AdoptSecret(state.sessionMasterSecret)) + .SetMasterSecret(serverContext.Crypto.AdoptSecret(state.sessionMasterSecret)) .SetNegotiatedVersion(securityParameters.NegotiatedVersion) .SetPeerCertificate(securityParameters.PeerCertificate) .SetPskIdentity(securityParameters.PskIdentity) @@ -379,11 +383,11 @@ namespace Org.BouncyCastle.Tls securityParameters.m_tlsUnique = securityParameters.PeerVerifyData; - state.serverContext.HandshakeComplete(state.server, state.tlsSession); + serverContext.HandshakeComplete(server, state.tlsSession); recordLayer.InitHeartbeat(state.heartbeat, HeartbeatMode.peer_allowed_to_send == state.heartbeatPolicy); - return new DtlsTransport(recordLayer, state.server.IgnoreCorruptDtlsRecords); + return new DtlsTransport(recordLayer, server.IgnoreCorruptDtlsRecords); } /// <exception cref="IOException"/> @@ -417,12 +421,13 @@ namespace Org.BouncyCastle.Tls /// <exception cref="IOException"/> internal virtual byte[] GenerateServerHello(ServerHandshakeState state, DtlsRecordLayer recordLayer) { - TlsServerContextImpl context = state.serverContext; - SecurityParameters securityParameters = context.SecurityParameters; + TlsServer server = state.server; + TlsServerContextImpl serverContext = state.serverContext; + SecurityParameters securityParameters = serverContext.SecurityParameters; - ProtocolVersion server_version = state.server.GetServerVersion(); + ProtocolVersion server_version = server.GetServerVersion(); { - if (!ProtocolVersion.Contains(context.ClientSupportedVersions, server_version)) + if (!ProtocolVersion.Contains(serverContext.ClientSupportedVersions, server_version)) throw new TlsFatalAlert(AlertDescription.internal_error); // TODO[dtls13] Read draft/RFC for guidance on the legacy_record_version field @@ -433,16 +438,16 @@ namespace Org.BouncyCastle.Tls //recordLayer.SetWriteVersion(legacy_record_version); securityParameters.m_negotiatedVersion = server_version; - TlsUtilities.NegotiatedVersionDtlsServer(context); + TlsUtilities.NegotiatedVersionDtlsServer(serverContext); } { bool useGmtUnixTime = ProtocolVersion.DTLSv12.IsEqualOrLaterVersionOf(server_version) - && state.server.ShouldUseGmtUnixTime(); + && server.ShouldUseGmtUnixTime(); - securityParameters.m_serverRandom = TlsProtocol.CreateRandomBlock(useGmtUnixTime, context); + securityParameters.m_serverRandom = TlsProtocol.CreateRandomBlock(useGmtUnixTime, serverContext); - if (!server_version.Equals(ProtocolVersion.GetLatestDtls(state.server.GetProtocolVersions()))) + if (!server_version.Equals(ProtocolVersion.GetLatestDtls(server.GetProtocolVersions()))) { TlsUtilities.WriteDowngradeMarker(server_version, securityParameters.ServerRandom); } @@ -451,7 +456,7 @@ namespace Org.BouncyCastle.Tls bool resumedSession = securityParameters.IsResumedSession; { - int cipherSuite = ValidateSelectedCipherSuite(state.server.GetSelectedCipherSuite(), + int cipherSuite = ValidateSelectedCipherSuite(server.GetSelectedCipherSuite(), AlertDescription.internal_error); if (!TlsUtilities.IsValidCipherSuiteSelection(state.offeredCipherSuites, cipherSuite) || @@ -464,9 +469,9 @@ namespace Org.BouncyCastle.Tls } state.serverExtensions = TlsExtensionsUtilities.EnsureExtensionsInitialised( - state.server.GetServerExtensions()); + server.GetServerExtensions()); - state.server.GetServerExtensionsForConnection(state.serverExtensions); + server.GetServerExtensionsForConnection(state.serverExtensions); ProtocolVersion legacy_version = server_version; if (server_version.IsLaterVersionOf(ProtocolVersion.DTLSv12)) @@ -519,17 +524,17 @@ namespace Org.BouncyCastle.Tls else { securityParameters.m_extendedMasterSecret = state.offeredExtendedMasterSecret - && state.server.ShouldUseExtendedMasterSecret(); + && server.ShouldUseExtendedMasterSecret(); if (securityParameters.IsExtendedMasterSecret) { TlsExtensionsUtilities.AddExtendedMasterSecretExtension(state.serverExtensions); } - else if (state.server.RequiresExtendedMasterSecret()) + else if (server.RequiresExtendedMasterSecret()) { throw new TlsFatalAlert(AlertDescription.handshake_failure); } - else if (resumedSession && !state.server.AllowLegacyResumption()) + else if (resumedSession && !server.AllowLegacyResumption()) { throw new TlsFatalAlert(AlertDescription.internal_error); } @@ -619,7 +624,7 @@ namespace Org.BouncyCastle.Tls state.tlsSession.SessionID, securityParameters.CipherSuite, state.serverExtensions); MemoryStream buf = new MemoryStream(); - serverHello.Encode(state.serverContext, buf); + serverHello.Encode(serverContext, buf); return buf.ToArray(); } @@ -682,12 +687,13 @@ namespace Org.BouncyCastle.Tls MemoryStream buf = new MemoryStream(body, false); - TlsServerContextImpl context = state.serverContext; - DigitallySigned certificateVerify = DigitallySigned.Parse(context, buf); + TlsServerContextImpl serverContext = state.serverContext; + DigitallySigned certificateVerify = DigitallySigned.Parse(serverContext, buf); TlsProtocol.AssertEmpty(buf); - TlsUtilities.VerifyCertificateVerifyClient(context, state.certificateRequest, certificateVerify, handshakeHash); + TlsUtilities.VerifyCertificateVerifyClient(serverContext, state.certificateRequest, certificateVerify, + handshakeHash); } /// <exception cref="IOException"/> @@ -714,44 +720,45 @@ namespace Org.BouncyCastle.Tls - TlsServerContextImpl context = state.serverContext; - SecurityParameters securityParameters = context.SecurityParameters; + TlsServer server = state.server; + TlsServerContextImpl serverContext = state.serverContext; + SecurityParameters securityParameters = serverContext.SecurityParameters; if (!legacy_version.IsDtls) throw new TlsFatalAlert(AlertDescription.illegal_parameter); - context.SetRsaPreMasterSecretVersion(legacy_version); + serverContext.SetRsaPreMasterSecretVersion(legacy_version); - context.SetClientSupportedVersions( + serverContext.SetClientSupportedVersions( TlsExtensionsUtilities.GetSupportedVersionsExtensionClient(state.clientExtensions)); ProtocolVersion client_version = legacy_version; - if (null == context.ClientSupportedVersions) + if (null == serverContext.ClientSupportedVersions) { if (client_version.IsLaterVersionOf(ProtocolVersion.DTLSv12)) { client_version = ProtocolVersion.DTLSv12; } - context.SetClientSupportedVersions(client_version.DownTo(ProtocolVersion.DTLSv10)); + serverContext.SetClientSupportedVersions(client_version.DownTo(ProtocolVersion.DTLSv10)); } else { - client_version = ProtocolVersion.GetLatestDtls(context.ClientSupportedVersions); + client_version = ProtocolVersion.GetLatestDtls(serverContext.ClientSupportedVersions); } if (!ProtocolVersion.SERVER_EARLIEST_SUPPORTED_DTLS.IsEqualOrEarlierVersionOf(client_version)) throw new TlsFatalAlert(AlertDescription.protocol_version); - context.SetClientVersion(client_version); + serverContext.SetClientVersion(client_version); - state.server.NotifyClientVersion(context.ClientVersion); + server.NotifyClientVersion(serverContext.ClientVersion); securityParameters.m_clientRandom = clientHello.Random; - state.server.NotifyFallback(Arrays.Contains(state.offeredCipherSuites, CipherSuite.TLS_FALLBACK_SCSV)); + server.NotifyFallback(Arrays.Contains(state.offeredCipherSuites, CipherSuite.TLS_FALLBACK_SCSV)); - state.server.NotifyOfferedCipherSuites(state.offeredCipherSuites); + server.NotifyOfferedCipherSuites(state.offeredCipherSuites); /* * TODO[resumption] Check RFC 7627 5.4. for required behaviour @@ -800,7 +807,7 @@ namespace Org.BouncyCastle.Tls } } - state.server.NotifySecureRenegotiation(securityParameters.IsSecureRenegotiation); + server.NotifySecureRenegotiation(securityParameters.IsSecureRenegotiation); state.offeredExtendedMasterSecret = TlsExtensionsUtilities.HasExtendedMasterSecretExtension( state.clientExtensions); @@ -833,14 +840,14 @@ namespace Org.BouncyCastle.Tls { if (HeartbeatMode.peer_allowed_to_send == heartbeatExtension.Mode) { - state.heartbeat = state.server.GetHeartbeat(); + state.heartbeat = server.GetHeartbeat(); } - state.heartbeatPolicy = state.server.GetHeartbeatPolicy(); + state.heartbeatPolicy = server.GetHeartbeatPolicy(); } } - state.server.ProcessClientExtensions(state.clientExtensions); + server.ProcessClientExtensions(state.clientExtensions); } } |