summary refs log tree commit diff
path: root/crypto/src/tls/TlsServerProtocol.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/tls/TlsServerProtocol.cs')
-rw-r--r--crypto/src/tls/TlsServerProtocol.cs140
1 files changed, 89 insertions, 51 deletions
diff --git a/crypto/src/tls/TlsServerProtocol.cs b/crypto/src/tls/TlsServerProtocol.cs
index e14fb7d70..40218a2fb 100644
--- a/crypto/src/tls/TlsServerProtocol.cs
+++ b/crypto/src/tls/TlsServerProtocol.cs
@@ -121,7 +121,8 @@ namespace Org.BouncyCastle.Tls
         }
 
         /// <exception cref="IOException"/>
-        protected virtual ServerHello Generate13ServerHello(ClientHello clientHello, bool afterHelloRetryRequest)
+        protected virtual ServerHello Generate13ServerHello(ClientHello clientHello,
+            HandshakeMessageInput clientHelloMessage, bool afterHelloRetryRequest)
         {
             SecurityParameters securityParameters = m_tlsServerContext.SecurityParameters;
 
@@ -136,6 +137,10 @@ namespace Org.BouncyCastle.Tls
             ProtocolVersion serverVersion = securityParameters.NegotiatedVersion;
             TlsCrypto crypto = m_tlsServerContext.Crypto;
 
+            // NOTE: Will only select for psk_dhe_ke
+            OfferedPsks.SelectedConfig selectedPsk = TlsUtilities.SelectPreSharedKey(m_tlsServerContext, m_tlsServer,
+                clientHelloExtensions, clientHelloMessage, m_handshakeHash, afterHelloRetryRequest);
+
             IList clientShares = TlsExtensionsUtilities.GetKeyShareClientHello(clientHelloExtensions);
             KeyShareEntry clientShare = null;
 
@@ -144,6 +149,23 @@ namespace Org.BouncyCastle.Tls
                 if (m_retryGroup < 0)
                     throw new TlsFatalAlert(AlertDescription.internal_error);
 
+                if (null == selectedPsk)
+                {
+                    /*
+                     * RFC 8446 4.2.3. If a server is authenticating via a certificate and the client has
+                     * not sent a "signature_algorithms" extension, then the server MUST abort the handshake
+                     * with a "missing_extension" alert.
+                     */
+                    if (null == securityParameters.ClientSigAlgs)
+                        throw new TlsFatalAlert(AlertDescription.missing_extension);
+                }
+                else
+                {
+                    // TODO[tls13] Maybe filter the offered PSKs by PRF algorithm before server selection instead
+                    if (selectedPsk.m_psk.PrfAlgorithm != securityParameters.PrfAlgorithm)
+                        throw new TlsFatalAlert(AlertDescription.illegal_parameter);
+                }
+
                 /*
                  * TODO[tls13] Confirm fields in the ClientHello haven't changed
                  * 
@@ -182,8 +204,7 @@ namespace Org.BouncyCastle.Tls
                  * not sent a "signature_algorithms" extension, then the server MUST abort the handshake
                  * with a "missing_extension" alert.
                  */
-                // TODO[tls13] Revisit this check if we add support for PSK-only key exchange.
-                if (null == securityParameters.ClientSigAlgs)
+                if (null == selectedPsk && null == securityParameters.ClientSigAlgs)
                     throw new TlsFatalAlert(AlertDescription.missing_extension);
 
                 m_tlsServer.ProcessClientExtensions(clientHelloExtensions);
@@ -218,6 +239,7 @@ namespace Org.BouncyCastle.Tls
                 }
 
                 {
+                    // TODO[tls13] Constrain selection when PSK selected
                     int cipherSuite = m_tlsServer.GetSelectedCipherSuite();
 
                     if (!TlsUtilities.IsValidCipherSuiteSelection(m_offeredCipherSuites, cipherSuite) ||
@@ -309,11 +331,17 @@ namespace Org.BouncyCastle.Tls
 
             this.m_expectSessionTicket = false;
 
-            // TODO[tls13-psk] Use PSK early secret if negotiated
             TlsSecret pskEarlySecret = null;
+            if (null != selectedPsk)
+            {
+                pskEarlySecret = selectedPsk.m_earlySecret;
 
-            TlsSecret sharedSecret = null;
+                this.m_selectedPsk13 = true;
 
+                TlsExtensionsUtilities.AddPreSharedKeyServerHello(serverHelloExtensions, selectedPsk.m_index);
+            }
+
+            TlsSecret sharedSecret;
             {
                 int namedGroup = clientShare.NamedGroup;
 
@@ -353,7 +381,8 @@ namespace Org.BouncyCastle.Tls
         }
 
         /// <exception cref="IOException"/>
-        protected virtual ServerHello GenerateServerHello(ClientHello clientHello)
+        protected virtual ServerHello GenerateServerHello(ClientHello clientHello,
+            HandshakeMessageInput clientHelloMessage)
         {
             ProtocolVersion clientLegacyVersion = clientHello.Version;
             if (!clientLegacyVersion.IsTls)
@@ -426,7 +455,7 @@ namespace Org.BouncyCastle.Tls
 
                 m_recordStream.SetWriteVersion(ProtocolVersion.TLSv12);
 
-                return Generate13ServerHello(clientHello, false);
+                return Generate13ServerHello(clientHello, clientHelloMessage, false);
             }
 
             m_recordStream.SetWriteVersion(serverVersion);
@@ -747,10 +776,9 @@ namespace Org.BouncyCastle.Tls
                 case CS_SERVER_HELLO_RETRY_REQUEST:
                 {
                     ClientHello clientHelloRetry = ReceiveClientHelloMessage(buf);
-                    buf.UpdateHash(m_handshakeHash);
                     this.m_connectionState = CS_CLIENT_HELLO_RETRY;
 
-                    ServerHello serverHello = Generate13ServerHello(clientHelloRetry, true);
+                    ServerHello serverHello = Generate13ServerHello(clientHelloRetry, buf, true);
                     SendServerHelloMessage(serverHello);
                     this.m_connectionState = CS_SERVER_HELLO;
 
@@ -866,10 +894,9 @@ namespace Org.BouncyCastle.Tls
                 case CS_START:
                 {
                     ClientHello clientHello = ReceiveClientHelloMessage(buf);
-                    buf.UpdateHash(m_handshakeHash);
                     this.m_connectionState = CS_CLIENT_HELLO;
 
-                    ServerHello serverHello = GenerateServerHello(clientHello);
+                    ServerHello serverHello = GenerateServerHello(clientHello, buf);
                     m_handshakeHash.NotifyPrfDetermined();
 
                     if (TlsUtilities.IsTlsV13(securityParameters.NegotiatedVersion))
@@ -898,6 +925,9 @@ namespace Org.BouncyCastle.Tls
                         break;
                     }
 
+                    // For TLS 1.3+, this was already done by GenerateServerHello
+                    buf.UpdateHash(m_handshakeHash);
+
                     SendServerHelloMessage(serverHello);
                     this.m_connectionState = CS_SERVER_HELLO;
 
@@ -1042,9 +1072,6 @@ namespace Org.BouncyCastle.Tls
                         m_tlsServer.ProcessClientSupplementalData(null);
                     }
 
-                    if (m_certificateRequest == null)
-                        throw new TlsFatalAlert(AlertDescription.unexpected_message);
-
                     ReceiveCertificateMessage(buf);
                     this.m_connectionState = CS_CLIENT_CERTIFICATE;
                     break;
@@ -1232,6 +1259,9 @@ namespace Org.BouncyCastle.Tls
         {
             // TODO[tls13] This currently just duplicates 'receiveCertificateMessage'
 
+            if (null == m_certificateRequest)
+                throw new TlsFatalAlert(AlertDescription.unexpected_message);
+
             Certificate.ParseOptions options = new Certificate.ParseOptions()
                 .SetMaxChainLength(m_tlsServer.GetMaxCertificateChainLength());
 
@@ -1267,6 +1297,9 @@ namespace Org.BouncyCastle.Tls
         /// <exception cref="IOException"/>
         protected virtual void ReceiveCertificateMessage(MemoryStream buf)
         {
+            if (null == m_certificateRequest)
+                throw new TlsFatalAlert(AlertDescription.unexpected_message);
+
             Certificate.ParseOptions options = new Certificate.ParseOptions()
                 .SetMaxChainLength(m_tlsServer.GetMaxCertificateChainLength());
 
@@ -1353,52 +1386,57 @@ namespace Org.BouncyCastle.Tls
             Send13EncryptedExtensionsMessage(m_serverExtensions);
             this.m_connectionState = CS_SERVER_ENCRYPTED_EXTENSIONS;
 
-            // CertificateRequest
+            if (m_selectedPsk13)
+            {
+                /*
+                 * For PSK-only key exchange, there's no CertificateRequest, Certificate, CertificateVerify.
+                 */
+            }
+            else
             {
-                this.m_certificateRequest = m_tlsServer.GetCertificateRequest();
-                if (null != m_certificateRequest)
+                // CertificateRequest
                 {
-                    if (!m_certificateRequest.HasCertificateRequestContext(TlsUtilities.EmptyBytes))
-                        throw new TlsFatalAlert(AlertDescription.internal_error);
+                    this.m_certificateRequest = m_tlsServer.GetCertificateRequest();
+                    if (null != m_certificateRequest)
+                    {
+                        if (!m_certificateRequest.HasCertificateRequestContext(TlsUtilities.EmptyBytes))
+                            throw new TlsFatalAlert(AlertDescription.internal_error);
 
-                    TlsUtilities.EstablishServerSigAlgs(securityParameters, m_certificateRequest);
+                        TlsUtilities.EstablishServerSigAlgs(securityParameters, m_certificateRequest);
 
-                    SendCertificateRequestMessage(m_certificateRequest);
-                    this.m_connectionState = CS_SERVER_CERTIFICATE_REQUEST;
+                        SendCertificateRequestMessage(m_certificateRequest);
+                        this.m_connectionState = CS_SERVER_CERTIFICATE_REQUEST;
+                    }
                 }
-            }
-
-            /*
-             * TODO[tls13] For PSK-only key exchange, there's no Certificate message.
-             */
 
-            TlsCredentialedSigner serverCredentials = TlsUtilities.Establish13ServerCredentials(m_tlsServer);
-            if (null == serverCredentials)
-                throw new TlsFatalAlert(AlertDescription.internal_error);
+                TlsCredentialedSigner serverCredentials = TlsUtilities.Establish13ServerCredentials(m_tlsServer);
+                if (null == serverCredentials)
+                    throw new TlsFatalAlert(AlertDescription.internal_error);
 
-            // Certificate
-            {
-                /*
-                 * TODO[tls13] Note that we are expecting the TlsServer implementation to take care of
-                 * e.g. adding optional "status_request" extension to each CertificateEntry.
-                 */
-                /*
-                 * No CertificateStatus message is sent; TLS 1.3 uses per-CertificateEntry
-                 * "status_request" extension instead.
-                 */
+                // Certificate
+                {
+                    /*
+                     * TODO[tls13] Note that we are expecting the TlsServer implementation to take care of
+                     * e.g. adding optional "status_request" extension to each CertificateEntry.
+                     */
+                    /*
+                     * No CertificateStatus message is sent; TLS 1.3 uses per-CertificateEntry
+                     * "status_request" extension instead.
+                     */
 
-                Certificate serverCertificate = serverCredentials.Certificate;
-                Send13CertificateMessage(serverCertificate);
-                securityParameters.m_tlsServerEndPoint = null;
-                this.m_connectionState = CS_SERVER_CERTIFICATE;
-            }
+                    Certificate serverCertificate = serverCredentials.Certificate;
+                    Send13CertificateMessage(serverCertificate);
+                    securityParameters.m_tlsServerEndPoint = null;
+                    this.m_connectionState = CS_SERVER_CERTIFICATE;
+                }
 
-            // CertificateVerify
-            {
-                DigitallySigned certificateVerify = TlsUtilities.Generate13CertificateVerify(m_tlsServerContext,
-                    serverCredentials, m_handshakeHash);
-                Send13CertificateVerifyMessage(certificateVerify);
-                this.m_connectionState = CS_CLIENT_CERTIFICATE_VERIFY;
+                // CertificateVerify
+                {
+                    DigitallySigned certificateVerify = TlsUtilities.Generate13CertificateVerify(m_tlsServerContext,
+                        serverCredentials, m_handshakeHash);
+                    Send13CertificateVerifyMessage(certificateVerify);
+                    this.m_connectionState = CS_CLIENT_CERTIFICATE_VERIFY;
+                }
             }
 
             // Finished