summary refs log tree commit diff
path: root/crypto/src/tls/DtlsClientProtocol.cs
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2022-05-10 12:54:22 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2022-05-10 12:54:22 +0700
commit217c08cdb0359f95c40f1a09e4e545a4552509fe (patch)
treebb6418fed2a682e42ea77a82cd2da6f3e923d929 /crypto/src/tls/DtlsClientProtocol.cs
parentAvoid duplicate call (diff)
downloadBouncyCastle.NET-ed25519-217c08cdb0359f95c40f1a09e4e545a4552509fe.tar.xz
Improve TLS handshake hash tracking
Diffstat (limited to 'crypto/src/tls/DtlsClientProtocol.cs')
-rw-r--r--crypto/src/tls/DtlsClientProtocol.cs110
1 files changed, 61 insertions, 49 deletions
diff --git a/crypto/src/tls/DtlsClientProtocol.cs b/crypto/src/tls/DtlsClientProtocol.cs
index 44f574e3a..dd273f3e7 100644
--- a/crypto/src/tls/DtlsClientProtocol.cs
+++ b/crypto/src/tls/DtlsClientProtocol.cs
@@ -137,6 +137,10 @@ namespace Org.BouncyCastle.Tls
             }
 
             handshake.HandshakeHash.NotifyPrfDetermined();
+            if (!ProtocolVersion.DTLSv12.Equals(securityParameters.NegotiatedVersion))
+            {
+                handshake.HandshakeHash.SealHashAlgorithms();
+            }
 
             ApplyMaxFragmentLengthExtension(recordLayer, securityParameters.MaxFragmentLength);
 
@@ -237,12 +241,6 @@ namespace Org.BouncyCastle.Tls
 
                 TlsUtilities.EstablishServerSigAlgs(securityParameters, state.certificateRequest);
 
-                /*
-                 * TODO Give the client a chance to immediately select the CertificateVerify hash
-                 * algorithm here to avoid tracking the other hash algorithms unnecessarily?
-                 */
-                TlsUtilities.TrackHashAlgorithms(handshake.HandshakeHash, securityParameters.ServerSigAlgs);
-
                 serverMessage = handshake.ReceiveMessage();
             }
             else
@@ -262,54 +260,71 @@ namespace Org.BouncyCastle.Tls
                 throw new TlsFatalAlert(AlertDescription.unexpected_message);
             }
 
-            IList clientSupplementalData = state.client.GetClientSupplementalData();
-            if (clientSupplementalData != null)
-            {
-                byte[] supplementalDataBody = GenerateSupplementalData(clientSupplementalData);
-                handshake.SendMessage(HandshakeType.supplemental_data, supplementalDataBody);
-            }
+            TlsCredentials clientAuthCredentials = null;
+            TlsCredentialedSigner clientAuthSigner = null;
+            Certificate clientAuthCertificate = null;
+            SignatureAndHashAlgorithm clientAuthAlgorithm = null;
+            TlsStreamSigner clientAuthStreamSigner = null;
 
-            if (null != state.certificateRequest)
+            if (state.certificateRequest != null)
             {
-                state.clientCredentials = TlsUtilities.EstablishClientCredentials(state.authentication,
+                clientAuthCredentials = TlsUtilities.EstablishClientCredentials(state.authentication,
                     state.certificateRequest);
+                if (clientAuthCredentials != null)
+                {
+                    clientAuthCertificate = clientAuthCredentials.Certificate;
 
-                /*
-                 * RFC 5246 If no suitable certificate is available, the client MUST send a certificate
-                 * message containing no certificates.
-                 * 
-                 * NOTE: In previous RFCs, this was SHOULD instead of MUST.
-                 */
+                    if (clientAuthCredentials is TlsCredentialedSigner)
+                    {
+                        clientAuthSigner = (TlsCredentialedSigner)clientAuthCredentials;
+                        clientAuthAlgorithm = TlsUtilities.GetSignatureAndHashAlgorithm(
+                            securityParameters.NegotiatedVersion, clientAuthSigner);
+                        clientAuthStreamSigner = clientAuthSigner.GetStreamSigner();
 
-                Certificate clientCertificate = null;
-                if (null != state.clientCredentials)
-                {
-                    clientCertificate = state.clientCredentials.Certificate;
-                }
+                        if (ProtocolVersion.DTLSv12.Equals(securityParameters.NegotiatedVersion))
+                        {
+                            TlsUtilities.VerifySupportedSignatureAlgorithm(securityParameters.ServerSigAlgs,
+                                clientAuthAlgorithm, AlertDescription.internal_error);
 
-                SendCertificateMessage(state.clientContext, handshake, clientCertificate, null);
-            }
+                            if (clientAuthStreamSigner == null)
+                            {
+                                TlsUtilities.TrackHashAlgorithmClient(handshake.HandshakeHash, clientAuthAlgorithm);
+                            }
+                        }
 
-            TlsCredentialedSigner credentialedSigner = null;
-            TlsStreamSigner streamSigner = null;
+                        if (clientAuthStreamSigner != null)
+                        {
+                            handshake.HandshakeHash.ForceBuffering();
+                        }
+                    }
+                }
+            }
 
-            if (null != state.clientCredentials)
+            if (ProtocolVersion.DTLSv12.Equals(securityParameters.NegotiatedVersion))
             {
-                state.keyExchange.ProcessClientCredentials(state.clientCredentials);
+                handshake.HandshakeHash.SealHashAlgorithms();
+            }
 
-                if (state.clientCredentials is TlsCredentialedSigner)
-                {
-                    credentialedSigner = (TlsCredentialedSigner)state.clientCredentials;
-                    streamSigner = credentialedSigner.GetStreamSigner();
-                }
+            if (clientAuthCredentials == null)
+            {
+                state.keyExchange.SkipClientCredentials();
             }
             else
             {
-                state.keyExchange.SkipClientCredentials();
+                state.keyExchange.ProcessClientCredentials(clientAuthCredentials);                    
             }
 
-            bool forceBuffering = streamSigner != null;
-            TlsUtilities.SealHandshakeHash(state.clientContext, handshake.HandshakeHash, forceBuffering);
+            IList clientSupplementalData = state.client.GetClientSupplementalData();
+            if (clientSupplementalData != null)
+            {
+                byte[] supplementalDataBody = GenerateSupplementalData(clientSupplementalData);
+                handshake.SendMessage(HandshakeType.supplemental_data, supplementalDataBody);
+            }
+
+            if (null != state.certificateRequest)
+            {
+                SendCertificateMessage(state.clientContext, handshake, clientAuthCertificate, null);
+            }
 
             byte[] clientKeyExchangeBody = GenerateClientKeyExchange(state);
             handshake.SendMessage(HandshakeType.client_key_exchange, clientKeyExchangeBody);
@@ -319,18 +334,16 @@ namespace Org.BouncyCastle.Tls
             TlsProtocol.EstablishMasterSecret(state.clientContext, state.keyExchange);
             recordLayer.InitPendingEpoch(TlsUtilities.InitCipher(state.clientContext));
 
+            if (clientAuthSigner != null)
             {
-                if (credentialedSigner != null)
-                {
-                    DigitallySigned certificateVerify = TlsUtilities.GenerateCertificateVerifyClient(
-                        state.clientContext, credentialedSigner, streamSigner, handshake.HandshakeHash);
-                    byte[] certificateVerifyBody = GenerateCertificateVerify(state, certificateVerify);
-                    handshake.SendMessage(HandshakeType.certificate_verify, certificateVerifyBody);
-                }
-
-                handshake.PrepareToFinish();
+                DigitallySigned certificateVerify = TlsUtilities.GenerateCertificateVerifyClient(state.clientContext,
+                    clientAuthSigner, clientAuthAlgorithm, clientAuthStreamSigner, handshake.HandshakeHash);
+                byte[] certificateVerifyBody = GenerateCertificateVerify(state, certificateVerify);
+                handshake.SendMessage(HandshakeType.certificate_verify, certificateVerifyBody);
             }
 
+            handshake.PrepareToFinish();
+
             securityParameters.m_localVerifyData = TlsUtilities.CalculateVerifyData(state.clientContext,
                 handshake.HandshakeHash, false);
             handshake.SendMessage(HandshakeType.finished, securityParameters.LocalVerifyData);
@@ -973,7 +986,6 @@ namespace Org.BouncyCastle.Tls
             internal TlsAuthentication authentication = null;
             internal CertificateStatus certificateStatus = null;
             internal CertificateRequest certificateRequest = null;
-            internal TlsCredentials clientCredentials = null;
             internal TlsHeartbeat heartbeat = null;
             internal short heartbeatPolicy = HeartbeatMode.peer_not_allowed_to_send;
         }