summary refs log tree commit diff
path: root/crypto/src/tls/TlsClientProtocol.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/tls/TlsClientProtocol.cs')
-rw-r--r--crypto/src/tls/TlsClientProtocol.cs55
1 files changed, 36 insertions, 19 deletions
diff --git a/crypto/src/tls/TlsClientProtocol.cs b/crypto/src/tls/TlsClientProtocol.cs
index 870a898f8..021e8da1b 100644
--- a/crypto/src/tls/TlsClientProtocol.cs
+++ b/crypto/src/tls/TlsClientProtocol.cs
@@ -953,41 +953,58 @@ namespace Org.BouncyCastle.Tls
              */
             securityParameters.m_statusRequestVersion = m_clientExtensions.Contains(ExtensionType.status_request) ? 1 : 0;
 
-            // TODO[tls13-psk] Use PSK early secret if negotiated
             TlsSecret pskEarlySecret = null;
-
-            if (null != m_clientBinders)
             {
-                // TODO[tls13-psk] Process the server's pre_shared_key response, if any
-                //int selected_identity = TlsExtensionsUtilities.GetPreSharedKeyServerHello(extensions);
+                int selected_identity = TlsExtensionsUtilities.GetPreSharedKeyServerHello(extensions);
+                TlsPsk selectedPsk = null;
 
-                // TODO[tls13-psk] Notify client of selected PSK
-                // pskEarlySecret = ...;
+                if (selected_identity >= 0)
+                {
+                    if (null == m_clientBinders || selected_identity >= m_clientBinders.m_psks.Length)
+                        throw new TlsFatalAlert(AlertDescription.illegal_parameter);
 
-                this.m_clientBinders = null;
+                    selectedPsk = m_clientBinders.m_psks[selected_identity];
+                    if (selectedPsk.PrfAlgorithm != securityParameters.PrfAlgorithm)
+                        throw new TlsFatalAlert(AlertDescription.illegal_parameter);
+
+                    pskEarlySecret = m_clientBinders.m_earlySecrets[selected_identity];
+                }
+
+                m_tlsClient.NotifySelectedPsk(selectedPsk);
             }
 
             TlsSecret sharedSecret = null;
-
             {
                 KeyShareEntry keyShareEntry = TlsExtensionsUtilities.GetKeyShareServerHello(extensions);
                 if (null == keyShareEntry)
                 {
-                    // TODO[tls13-psk] This would be OK for PskKeyExchangeMode.psk_ke (and not after HRR)
-                    throw new TlsFatalAlert(AlertDescription.illegal_parameter);
+                    if (afterHelloRetryRequest
+                        || null == pskEarlySecret
+                        || !Arrays.Contains(m_clientBinders.m_pskKeyExchangeModes, PskKeyExchangeMode.psk_ke))
+                    {
+                        throw new TlsFatalAlert(AlertDescription.illegal_parameter);
+                    }
                 }
+                else
+                {
+                    if (null != pskEarlySecret
+                        && !Arrays.Contains(m_clientBinders.m_pskKeyExchangeModes, PskKeyExchangeMode.psk_dhe_ke))
+                    {
+                        throw new TlsFatalAlert(AlertDescription.illegal_parameter);
+                    }
 
-                if (!m_clientAgreements.Contains(keyShareEntry.NamedGroup))
-                    throw new TlsFatalAlert(AlertDescription.illegal_parameter);
-
-                TlsAgreement agreement = (TlsAgreement)m_clientAgreements[keyShareEntry.NamedGroup];
-
-                this.m_clientAgreements = null;
+                    TlsAgreement agreement = (TlsAgreement)m_clientAgreements[keyShareEntry.NamedGroup];
+                    if (null == agreement)
+                        throw new TlsFatalAlert(AlertDescription.illegal_parameter);
 
-                agreement.ReceivePeerValue(keyShareEntry.KeyExchange);
-                sharedSecret = agreement.CalculateSecret();
+                    agreement.ReceivePeerValue(keyShareEntry.KeyExchange);
+                    sharedSecret = agreement.CalculateSecret();
+                }
             }
 
+            this.m_clientAgreements = null;
+            this.m_clientBinders = null;
+
             TlsUtilities.Establish13PhaseSecrets(m_tlsClientContext, pskEarlySecret, sharedSecret);
 
             {