summary refs log tree commit diff
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2021-10-17 21:46:19 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2021-10-17 21:46:19 +0700
commitbd2fe5262f97293908320481e0eeefb0a92b364c (patch)
treef598e7fd9efd794f5abff6bd5345941b7d9be8ea
parentTLS 1.3 PSK server-side work (diff)
downloadBouncyCastle.NET-ed25519-bd2fe5262f97293908320481e0eeefb0a92b364c.tar.xz
Server-side PSK selection
-rw-r--r--crypto/src/tls/OfferedPsks.cs36
-rw-r--r--crypto/src/tls/TlsUtilities.cs71
2 files changed, 97 insertions, 10 deletions
diff --git a/crypto/src/tls/OfferedPsks.cs b/crypto/src/tls/OfferedPsks.cs
index 9eddd2e23..1cc8a2a68 100644
--- a/crypto/src/tls/OfferedPsks.cs
+++ b/crypto/src/tls/OfferedPsks.cs
@@ -26,6 +26,22 @@ namespace Org.BouncyCastle.Tls
             }
         }
 
+        internal class SelectedConfig
+        {
+            internal readonly int m_index;
+            internal readonly TlsPsk m_psk;
+            internal readonly short[] m_pskKeyExchangeModes;
+            internal readonly TlsSecret m_earlySecret;
+
+            internal SelectedConfig(int index, TlsPsk psk, short[] pskKeyExchangeModes, TlsSecret earlySecret)
+            {
+                this.m_index = index;
+                this.m_psk = psk;
+                this.m_pskKeyExchangeModes = pskKeyExchangeModes;
+                this.m_earlySecret = earlySecret;
+            }
+        }
+
         private readonly IList m_identities;
         private readonly IList m_binders;
         private readonly int m_bindersSize;
@@ -49,16 +65,6 @@ namespace Org.BouncyCastle.Tls
             this.m_bindersSize = bindersSize;
         }
 
-        internal byte[] GetBinderForIdentity(PskIdentity matchIdentity)
-        {
-            for (int i = 0, count = m_identities.Count; i < count; ++i)
-            {
-                if (matchIdentity.Equals(m_identities[i]))
-                    return (byte[])m_binders[i];
-            }
-            return null;
-        }
-
         public IList Binders
         {
             get { return m_binders; }
@@ -74,6 +80,16 @@ namespace Org.BouncyCastle.Tls
             get { return m_identities; }
         }
 
+        public int GetIndexOfIdentity(PskIdentity pskIdentity)
+        {
+            for (int i = 0, count = m_identities.Count; i < count; ++i)
+            {
+                if (pskIdentity.Equals(m_identities[i]))
+                    return i;
+            }
+            return -1;
+        }
+
         /// <exception cref="IOException"/>
         public void Encode(Stream output)
         {
diff --git a/crypto/src/tls/TlsUtilities.cs b/crypto/src/tls/TlsUtilities.cs
index b03548398..d1e046965 100644
--- a/crypto/src/tls/TlsUtilities.cs
+++ b/crypto/src/tls/TlsUtilities.cs
@@ -5526,6 +5526,77 @@ namespace Org.BouncyCastle.Tls
             return result;
         }
 
+        internal static OfferedPsks.SelectedConfig SelectPreSharedKey(TlsServerContext serverContext, TlsServer server,
+            IDictionary clientHelloExtensions, HandshakeMessageInput clientHelloMessage, TlsHandshakeHash handshakeHash,
+            bool afterHelloRetryRequest)
+        {
+            bool handshakeHashUpdated = false;
+
+            OfferedPsks offeredPsks = TlsExtensionsUtilities.GetPreSharedKeyClientHello(clientHelloExtensions);
+            if (null != offeredPsks)
+            {
+                short[] pskKeyExchangeModes = TlsExtensionsUtilities.GetPskKeyExchangeModesExtension(
+                    clientHelloExtensions);
+                if (IsNullOrEmpty(pskKeyExchangeModes))
+                    throw new TlsFatalAlert(AlertDescription.missing_extension);
+
+                // TODO[tls13] Add support for psk_ke?
+                if (Arrays.Contains(pskKeyExchangeModes, PskKeyExchangeMode.psk_dhe_ke))
+                {
+                    // TODO[tls13] Prefer to get the exact index from the server?
+                    TlsPskExternal psk = server.GetExternalPsk(offeredPsks.Identities);
+                    if (null != psk)
+                    {
+                        int index = offeredPsks.GetIndexOfIdentity(new PskIdentity(psk.Identity, 0L));
+                        if (index >= 0)
+                        {
+                            byte[] binder = (byte[])offeredPsks.Binders[index];
+
+                            TlsCrypto crypto = serverContext.Crypto;
+                            TlsSecret earlySecret = GetPskEarlySecret(crypto, psk);
+
+                            // TODO[tls13-psk] Handle resumption PSKs
+                            bool isExternalPsk = true;
+                            int pskCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(psk.PrfAlgorithm);
+
+                            byte[] transcriptHash;
+                            {
+                                handshakeHashUpdated = true;
+                                int bindersSize = offeredPsks.BindersSize;
+                                clientHelloMessage.UpdateHashPrefix(handshakeHash, bindersSize);
+
+                                if (afterHelloRetryRequest)
+                                {
+                                    transcriptHash = handshakeHash.GetFinalHash(pskCryptoHashAlgorithm);
+                                }
+                                else
+                                {
+                                    TlsHash hash = crypto.CreateHash(pskCryptoHashAlgorithm);
+                                    handshakeHash.CopyBufferTo(new TlsHashSink(hash));
+                                    transcriptHash = hash.CalculateHash();
+                                }
+
+                                clientHelloMessage.UpdateHashSuffix(handshakeHash, bindersSize);
+                            }
+
+                            byte[] calculatedBinder = CalculatePskBinder(crypto, isExternalPsk, pskCryptoHashAlgorithm,
+                                earlySecret, transcriptHash);
+
+                            if (Arrays.ConstantTimeAreEqual(calculatedBinder, binder))
+                                return new OfferedPsks.SelectedConfig(index, psk, pskKeyExchangeModes, earlySecret);
+                        }
+                    }
+                }
+            }
+
+            if (!handshakeHashUpdated)
+            {
+                clientHelloMessage.UpdateHash(handshakeHash);
+            }
+
+            return null;
+        }
+
         internal static TlsSecret GetPskEarlySecret(TlsCrypto crypto, TlsPsk psk)
         {
             int cryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(psk.PrfAlgorithm);