summary refs log tree commit diff
path: root/crypto/src/tls/TlsPskKeyExchange.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/tls/TlsPskKeyExchange.cs')
-rw-r--r--crypto/src/tls/TlsPskKeyExchange.cs305
1 files changed, 305 insertions, 0 deletions
diff --git a/crypto/src/tls/TlsPskKeyExchange.cs b/crypto/src/tls/TlsPskKeyExchange.cs
new file mode 100644
index 000000000..1055fdc53
--- /dev/null
+++ b/crypto/src/tls/TlsPskKeyExchange.cs
@@ -0,0 +1,305 @@
+using System;
+using System.IO;
+
+using Org.BouncyCastle.Tls.Crypto;
+using Org.BouncyCastle.Utilities;
+
+namespace Org.BouncyCastle.Tls
+{
+    /// <summary>(D)TLS PSK key exchange (RFC 4279).</summary>
+    public class TlsPskKeyExchange
+        : AbstractTlsKeyExchange
+    {
+        private static int CheckKeyExchange(int keyExchange)
+        {
+            switch (keyExchange)
+            {
+            case KeyExchangeAlgorithm.DHE_PSK:
+            case KeyExchangeAlgorithm.ECDHE_PSK:
+            case KeyExchangeAlgorithm.PSK:
+            case KeyExchangeAlgorithm.RSA_PSK:
+                return keyExchange;
+            default:
+                throw new ArgumentException("unsupported key exchange algorithm", "keyExchange");
+            }
+        }
+
+        protected TlsPskIdentity m_pskIdentity;
+        protected TlsPskIdentityManager m_pskIdentityManager;
+        protected TlsDHGroupVerifier m_dhGroupVerifier;
+
+        protected byte[] m_psk_identity_hint = null;
+        protected byte[] m_psk = null;
+
+        protected TlsDHConfig m_dhConfig;
+        protected TlsECConfig m_ecConfig;
+        protected TlsAgreement m_agreement;
+
+        protected TlsCredentialedDecryptor m_serverCredentials = null;
+        protected TlsCertificate m_serverCertificate;
+        protected TlsSecret m_preMasterSecret;
+
+        public TlsPskKeyExchange(int keyExchange, TlsPskIdentity pskIdentity, TlsDHGroupVerifier dhGroupVerifier)
+            : this(keyExchange, pskIdentity, null, dhGroupVerifier, null, null)
+        {
+        }
+
+        public TlsPskKeyExchange(int keyExchange, TlsPskIdentityManager pskIdentityManager,
+            TlsDHConfig dhConfig, TlsECConfig ecConfig)
+            : this(keyExchange, null, pskIdentityManager, null, dhConfig, ecConfig)
+        {
+        }
+
+        private TlsPskKeyExchange(int keyExchange, TlsPskIdentity pskIdentity, TlsPskIdentityManager pskIdentityManager,
+            TlsDHGroupVerifier dhGroupVerifier, TlsDHConfig dhConfig, TlsECConfig ecConfig)
+            : base(CheckKeyExchange(keyExchange))
+        {
+            this.m_pskIdentity = pskIdentity;
+            this.m_pskIdentityManager = pskIdentityManager;
+            this.m_dhGroupVerifier = dhGroupVerifier;
+            this.m_dhConfig = dhConfig;
+            this.m_ecConfig = ecConfig;
+        }
+
+        public override void SkipServerCredentials()
+        {
+            if (m_keyExchange == KeyExchangeAlgorithm.RSA_PSK)
+                throw new TlsFatalAlert(AlertDescription.internal_error);
+        }
+
+        public override void ProcessServerCredentials(TlsCredentials serverCredentials)
+        {
+            if (m_keyExchange != KeyExchangeAlgorithm.RSA_PSK)
+                throw new TlsFatalAlert(AlertDescription.internal_error);
+
+            this.m_serverCredentials = TlsUtilities.RequireDecryptorCredentials(serverCredentials);
+        }
+
+        public override void ProcessServerCertificate(Certificate serverCertificate)
+        {
+            if (m_keyExchange != KeyExchangeAlgorithm.RSA_PSK)
+                throw new TlsFatalAlert(AlertDescription.unexpected_message);
+
+            this.m_serverCertificate = serverCertificate.GetCertificateAt(0).CheckUsageInRole(ConnectionEnd.server,
+                TlsCertificateRole.RsaEncryption);
+        }
+
+        public override byte[] GenerateServerKeyExchange()
+        {
+            this.m_psk_identity_hint = m_pskIdentityManager.GetHint();
+
+            if (this.m_psk_identity_hint == null && !RequiresServerKeyExchange)
+                return null;
+
+            MemoryStream buf = new MemoryStream();
+
+            if (this.m_psk_identity_hint == null)
+            {
+                TlsUtilities.WriteOpaque16(TlsUtilities.EmptyBytes, buf);
+            }
+            else
+            {
+                TlsUtilities.WriteOpaque16(this.m_psk_identity_hint, buf);
+            }
+
+            if (this.m_keyExchange == KeyExchangeAlgorithm.DHE_PSK)
+            {
+                if (this.m_dhConfig == null)
+                    throw new TlsFatalAlert(AlertDescription.internal_error);
+
+                TlsDHUtilities.WriteDHConfig(m_dhConfig, buf);
+
+                this.m_agreement = m_context.Crypto.CreateDHDomain(m_dhConfig).CreateDH();
+
+                GenerateEphemeralDH(buf);
+            }
+            else if (this.m_keyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
+            {
+                if (this.m_ecConfig == null)
+                    throw new TlsFatalAlert(AlertDescription.internal_error);
+
+                TlsEccUtilities.WriteECConfig(m_ecConfig, buf);
+
+                this.m_agreement = m_context.Crypto.CreateECDomain(m_ecConfig).CreateECDH();
+
+                GenerateEphemeralECDH(buf);
+            }
+
+            return buf.ToArray();
+        }
+
+        public override bool RequiresServerKeyExchange
+        {
+            get
+            {
+                switch (m_keyExchange)
+                {
+                case KeyExchangeAlgorithm.DHE_PSK:
+                case KeyExchangeAlgorithm.ECDHE_PSK:
+                    return true;
+                default:
+                    return false;
+                }
+            }
+        }
+
+        public override void ProcessServerKeyExchange(Stream input)
+        {
+            this.m_psk_identity_hint = TlsUtilities.ReadOpaque16(input);
+
+            if (this.m_keyExchange == KeyExchangeAlgorithm.DHE_PSK)
+            {
+                this.m_dhConfig = TlsDHUtilities.ReceiveDHConfig(m_context, m_dhGroupVerifier, input);
+
+                byte[] y = TlsUtilities.ReadOpaque16(input, 1);
+
+                this.m_agreement = m_context.Crypto.CreateDHDomain(m_dhConfig).CreateDH();
+
+                ProcessEphemeralDH(y);
+            }
+            else if (this.m_keyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
+            {
+                this.m_ecConfig = TlsEccUtilities.ReceiveECDHConfig(m_context, input);
+
+                byte[] point = TlsUtilities.ReadOpaque8(input, 1);
+
+                this.m_agreement = m_context.Crypto.CreateECDomain(m_ecConfig).CreateECDH();
+
+                ProcessEphemeralECDH(point);
+            }
+        }
+
+        public override void ProcessClientCredentials(TlsCredentials clientCredentials)
+        {
+            throw new TlsFatalAlert(AlertDescription.internal_error);
+        }
+
+        public override void GenerateClientKeyExchange(Stream output)
+        {
+            if (m_psk_identity_hint == null)
+            {
+                m_pskIdentity.SkipIdentityHint();
+            }
+            else
+            {
+                m_pskIdentity.NotifyIdentityHint(m_psk_identity_hint);
+            }
+
+            byte[] psk_identity = m_pskIdentity.GetPskIdentity();
+            if (psk_identity == null)
+                throw new TlsFatalAlert(AlertDescription.internal_error);
+
+            this.m_psk = m_pskIdentity.GetPsk();
+            if (m_psk == null)
+                throw new TlsFatalAlert(AlertDescription.internal_error);
+
+            TlsUtilities.WriteOpaque16(psk_identity, output);
+
+            m_context.SecurityParameters.m_pskIdentity = Arrays.Clone(psk_identity);
+
+            if (this.m_keyExchange == KeyExchangeAlgorithm.DHE_PSK)
+            {
+                GenerateEphemeralDH(output);
+            }
+            else if (this.m_keyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
+            {
+                GenerateEphemeralECDH(output);
+            }
+            else if (this.m_keyExchange == KeyExchangeAlgorithm.RSA_PSK)
+            {
+                this.m_preMasterSecret = TlsRsaUtilities.GenerateEncryptedPreMasterSecret(m_context,
+                    m_serverCertificate, output);
+            }
+        }
+
+        public override void ProcessClientKeyExchange(Stream input)
+        {
+            byte[] psk_identity = TlsUtilities.ReadOpaque16(input);
+
+            this.m_psk = m_pskIdentityManager.GetPsk(psk_identity);
+            if (m_psk == null)
+                throw new TlsFatalAlert(AlertDescription.unknown_psk_identity);
+
+            m_context.SecurityParameters.m_pskIdentity = psk_identity;
+
+            if (this.m_keyExchange == KeyExchangeAlgorithm.DHE_PSK)
+            {
+                byte[] y = TlsUtilities.ReadOpaque16(input, 1);
+
+                ProcessEphemeralDH(y);
+            }
+            else if (this.m_keyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
+            {
+                byte[] point = TlsUtilities.ReadOpaque8(input, 1);
+
+                ProcessEphemeralECDH(point);
+            }
+            else if (this.m_keyExchange == KeyExchangeAlgorithm.RSA_PSK)
+            {
+                byte[] encryptedPreMasterSecret = TlsUtilities.ReadEncryptedPms(m_context, input);
+
+                this.m_preMasterSecret = m_serverCredentials.Decrypt(new TlsCryptoParameters(m_context),
+                    encryptedPreMasterSecret);
+            }
+        }
+
+        public override TlsSecret GeneratePreMasterSecret()
+        {
+            byte[] other_secret = GenerateOtherSecret(m_psk.Length);
+
+            MemoryStream buf = new MemoryStream(4 + other_secret.Length + m_psk.Length);
+            TlsUtilities.WriteOpaque16(other_secret, buf);
+            TlsUtilities.WriteOpaque16(m_psk, buf);
+
+            Array.Clear(m_psk, 0, m_psk.Length);
+            this.m_psk = null;
+
+            return m_context.Crypto.CreateSecret(buf.ToArray());
+        }
+
+        protected virtual void GenerateEphemeralDH(Stream output)
+        {
+            byte[] y = m_agreement.GenerateEphemeral();
+            TlsUtilities.WriteOpaque16(y, output);
+        }
+
+        protected virtual void GenerateEphemeralECDH(Stream output)
+        {
+            byte[] point = m_agreement.GenerateEphemeral();
+            TlsUtilities.WriteOpaque8(point, output);
+        }
+
+        protected virtual byte[] GenerateOtherSecret(int pskLength)
+        {
+            if (this.m_keyExchange == KeyExchangeAlgorithm.PSK)
+                return new byte[pskLength];
+
+            if (this.m_keyExchange == KeyExchangeAlgorithm.DHE_PSK ||
+                this.m_keyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
+            {
+                if (m_agreement != null)
+                    return m_agreement.CalculateSecret().Extract();
+            }
+
+            if (this.m_keyExchange == KeyExchangeAlgorithm.RSA_PSK)
+            {
+                if (m_preMasterSecret != null)
+                    return this.m_preMasterSecret.Extract();
+            }
+
+            throw new TlsFatalAlert(AlertDescription.internal_error);
+        }
+
+        protected virtual void ProcessEphemeralDH(byte[] y)
+        {
+            this.m_agreement.ReceivePeerValue(y);
+        }
+
+        protected virtual void ProcessEphemeralECDH(byte[] point)
+        {
+            TlsEccUtilities.CheckPointEncoding(m_ecConfig.NamedGroup, point);
+
+            this.m_agreement.ReceivePeerValue(point);
+        }
+    }
+}