diff --git a/crypto/src/crypto/tls/TlsPskKeyExchange.cs b/crypto/src/crypto/tls/TlsPskKeyExchange.cs
index 24bf433dd..cd13e3438 100644
--- a/crypto/src/crypto/tls/TlsPskKeyExchange.cs
+++ b/crypto/src/crypto/tls/TlsPskKeyExchange.cs
@@ -1,68 +1,118 @@
using System;
+using System.Collections;
using System.IO;
using Org.BouncyCastle.Asn1.X509;
using Org.BouncyCastle.Crypto.Parameters;
-using Org.BouncyCastle.Math;
using Org.BouncyCastle.Security;
namespace Org.BouncyCastle.Crypto.Tls
{
- internal class TlsPskKeyExchange
- : TlsKeyExchange
+ /// <summary>(D)TLS PSK key exchange (RFC 4279).</summary>
+ public class TlsPskKeyExchange
+ : AbstractTlsKeyExchange
{
- protected TlsContext context;
- protected int keyExchange;
- protected TlsPskIdentity pskIdentity;
+ protected TlsPskIdentity mPskIdentity;
+ protected DHParameters mDHParameters;
+ protected int[] mNamedCurves;
+ protected byte[] mClientECPointFormats, mServerECPointFormats;
- protected byte[] psk_identity_hint = null;
+ protected byte[] mPskIdentityHint = null;
- protected DHPublicKeyParameters dhAgreeServerPublicKey = null;
- protected DHPrivateKeyParameters dhAgreeClientPrivateKey = null;
+ protected DHPrivateKeyParameters mDHAgreePrivateKey = null;
+ protected DHPublicKeyParameters mDHAgreePublicKey = null;
- protected AsymmetricKeyParameter serverPublicKey = null;
- protected RsaKeyParameters rsaServerPublicKey = null;
- protected byte[] premasterSecret;
+ protected AsymmetricKeyParameter mServerPublicKey = null;
+ protected RsaKeyParameters mRsaServerPublicKey = null;
+ protected TlsEncryptionCredentials mServerCredentials = null;
+ protected byte[] mPremasterSecret;
- internal TlsPskKeyExchange(TlsContext context, int keyExchange,
- TlsPskIdentity pskIdentity)
+ public TlsPskKeyExchange(int keyExchange, IList supportedSignatureAlgorithms, TlsPskIdentity pskIdentity,
+ DHParameters dhParameters, int[] namedCurves, byte[] clientECPointFormats, byte[] serverECPointFormats)
+ : base(keyExchange, supportedSignatureAlgorithms)
{
switch (keyExchange)
{
- case KeyExchangeAlgorithm.PSK:
- case KeyExchangeAlgorithm.RSA_PSK:
- case KeyExchangeAlgorithm.DHE_PSK:
- break;
- default:
- throw new ArgumentException("unsupported key exchange algorithm", "keyExchange");
+ case KeyExchangeAlgorithm.DHE_PSK:
+ case KeyExchangeAlgorithm.ECDHE_PSK:
+ case KeyExchangeAlgorithm.PSK:
+ case KeyExchangeAlgorithm.RSA_PSK:
+ break;
+ default:
+ throw new InvalidOperationException("unsupported key exchange algorithm");
}
- this.context = context;
- this.keyExchange = keyExchange;
- this.pskIdentity = pskIdentity;
+ this.mPskIdentity = pskIdentity;
+ this.mDHParameters = dhParameters;
+ this.mNamedCurves = namedCurves;
+ this.mClientECPointFormats = clientECPointFormats;
+ this.mServerECPointFormats = serverECPointFormats;
}
- public virtual void SkipServerCertificate()
+ public override void SkipServerCredentials()
{
- if (keyExchange == KeyExchangeAlgorithm.RSA_PSK)
- {
+ if (mKeyExchange == KeyExchangeAlgorithm.RSA_PSK)
throw new TlsFatalAlert(AlertDescription.unexpected_message);
- }
}
- public virtual void ProcessServerCertificate(Certificate serverCertificate)
+ public override void ProcessServerCredentials(TlsCredentials serverCredentials)
{
- if (keyExchange != KeyExchangeAlgorithm.RSA_PSK)
+ if (!(serverCredentials is TlsEncryptionCredentials))
+ throw new TlsFatalAlert(AlertDescription.internal_error);
+
+ ProcessServerCertificate(serverCredentials.Certificate);
+
+ this.mServerCredentials = (TlsEncryptionCredentials)serverCredentials;
+ }
+
+ public override byte[] GenerateServerKeyExchange()
+ {
+ // TODO[RFC 4279] Need a server-side PSK API to determine hint and resolve identities to keys
+ this.mPskIdentityHint = null;
+
+ if (this.mPskIdentityHint == null && !RequiresServerKeyExchange)
+ return null;
+
+ MemoryStream buf = new MemoryStream();
+
+ if (this.mPskIdentityHint == null)
{
- throw new TlsFatalAlert(AlertDescription.unexpected_message);
+ TlsUtilities.WriteOpaque16(TlsUtilities.EmptyBytes, buf);
+ }
+ else
+ {
+ TlsUtilities.WriteOpaque16(this.mPskIdentityHint, buf);
}
+ if (this.mKeyExchange == KeyExchangeAlgorithm.DHE_PSK)
+ {
+ if (this.mDHParameters == null)
+ throw new TlsFatalAlert(AlertDescription.internal_error);
+
+ this.mDHAgreePrivateKey = TlsDHUtilities.GenerateEphemeralServerKeyExchange(context.SecureRandom,
+ this.mDHParameters, buf);
+ }
+ else if (this.mKeyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
+ {
+ // TODO[RFC 5489]
+ }
+
+ return buf.ToArray();
+ }
+
+ public override void ProcessServerCertificate(Certificate serverCertificate)
+ {
+ if (mKeyExchange != KeyExchangeAlgorithm.RSA_PSK)
+ throw new TlsFatalAlert(AlertDescription.unexpected_message);
+ if (serverCertificate.IsEmpty)
+ throw new TlsFatalAlert(AlertDescription.bad_certificate);
+
X509CertificateStructure x509Cert = serverCertificate.GetCertificateAt(0);
- SubjectPublicKeyInfo keyInfo = x509Cert.SubjectPublicKeyInfo;
+ SubjectPublicKeyInfo keyInfo = x509Cert.SubjectPublicKeyInfo;
try
{
- this.serverPublicKey = PublicKeyFactory.CreateKey(keyInfo);
+ this.mServerPublicKey = PublicKeyFactory.CreateKey(keyInfo);
}
catch (Exception e)
{
@@ -70,107 +120,92 @@ namespace Org.BouncyCastle.Crypto.Tls
}
// Sanity check the PublicKeyFactory
- if (this.serverPublicKey.IsPrivate)
- {
+ if (this.mServerPublicKey.IsPrivate)
throw new TlsFatalAlert(AlertDescription.internal_error);
- }
- this.rsaServerPublicKey = ValidateRsaPublicKey((RsaKeyParameters)this.serverPublicKey);
+ this.mRsaServerPublicKey = ValidateRsaPublicKey((RsaKeyParameters)this.mServerPublicKey);
TlsUtilities.ValidateKeyUsage(x509Cert, KeyUsage.KeyEncipherment);
- // TODO
- /*
- * Perform various checks per RFC2246 7.4.2: "Unless otherwise specified, the
- * signing algorithm for the certificate must be the same as the algorithm for the
- * certificate key."
- */
+ base.ProcessServerCertificate(serverCertificate);
}
- public virtual void SkipServerKeyExchange()
+ public override bool RequiresServerKeyExchange
{
- if (keyExchange == KeyExchangeAlgorithm.DHE_PSK)
+ get
{
- throw new TlsFatalAlert(AlertDescription.unexpected_message);
+ switch (mKeyExchange)
+ {
+ case KeyExchangeAlgorithm.DHE_PSK:
+ case KeyExchangeAlgorithm.ECDHE_PSK:
+ return true;
+ default:
+ return false;
+ }
}
-
- this.psk_identity_hint = TlsUtilities.EmptyBytes;
}
- public virtual void ProcessServerKeyExchange(Stream input)
+ public override void ProcessServerKeyExchange(Stream input)
{
- this.psk_identity_hint = TlsUtilities.ReadOpaque16(input);
+ this.mPskIdentityHint = TlsUtilities.ReadOpaque16(input);
- if (this.keyExchange == KeyExchangeAlgorithm.DHE_PSK)
+ if (this.mKeyExchange == KeyExchangeAlgorithm.DHE_PSK)
{
- byte[] pBytes = TlsUtilities.ReadOpaque16(input);
- byte[] gBytes = TlsUtilities.ReadOpaque16(input);
- byte[] YsBytes = TlsUtilities.ReadOpaque16(input);
+ ServerDHParams serverDHParams = ServerDHParams.Parse(input);
- BigInteger p = new BigInteger(1, pBytes);
- BigInteger g = new BigInteger(1, gBytes);
- BigInteger Ys = new BigInteger(1, YsBytes);
-
- this.dhAgreeServerPublicKey = TlsDHUtilities.ValidateDHPublicKey(
- new DHPublicKeyParameters(Ys, new DHParameters(p, g)));
+ this.mDHAgreePublicKey = TlsDHUtilities.ValidateDHPublicKey(serverDHParams.PublicKey);
}
- else if (this.psk_identity_hint.Length == 0)
+ else if (this.mKeyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
{
- // TODO Should we enforce that this message should have been skipped if hint is empty?
- //throw new TlsFatalAlert(AlertDescription.unexpected_message);
+ // TODO[RFC 5489]
}
}
- public virtual void ValidateCertificateRequest(CertificateRequest certificateRequest)
+ public override void ValidateCertificateRequest(CertificateRequest certificateRequest)
{
throw new TlsFatalAlert(AlertDescription.unexpected_message);
}
- public virtual void SkipClientCredentials()
- {
- // OK
- }
-
- public virtual void ProcessClientCredentials(TlsCredentials clientCredentials)
+ public override void ProcessClientCredentials(TlsCredentials clientCredentials)
{
throw new TlsFatalAlert(AlertDescription.internal_error);
}
- public virtual void GenerateClientKeyExchange(Stream output)
+ public override void GenerateClientKeyExchange(Stream output)
{
- if (psk_identity_hint == null)
+ if (mPskIdentityHint == null)
{
- pskIdentity.SkipIdentityHint();
+ mPskIdentity.SkipIdentityHint();
}
else
{
- pskIdentity.NotifyIdentityHint(psk_identity_hint);
+ mPskIdentity.NotifyIdentityHint(mPskIdentityHint);
}
- byte[] psk_identity = pskIdentity.GetPskIdentity();
+ byte[] psk_identity = mPskIdentity.GetPskIdentity();
TlsUtilities.WriteOpaque16(psk_identity, output);
- if (this.keyExchange == KeyExchangeAlgorithm.DHE_PSK)
+ if (this.mKeyExchange == KeyExchangeAlgorithm.DHE_PSK)
{
- this.dhAgreeClientPrivateKey = TlsDHUtilities.GenerateEphemeralClientKeyExchange(
- context.SecureRandom, this.dhAgreeServerPublicKey.Parameters, output);
+ this.mDHAgreePrivateKey = TlsDHUtilities.GenerateEphemeralClientKeyExchange(context.SecureRandom,
+ mDHAgreePublicKey.Parameters, output);
}
- else if (this.keyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
+ else if (this.mKeyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
{
// TODO[RFC 5489]
throw new TlsFatalAlert(AlertDescription.internal_error);
}
- else if (this.keyExchange == KeyExchangeAlgorithm.RSA_PSK)
+ else if (this.mKeyExchange == KeyExchangeAlgorithm.RSA_PSK)
{
- this.premasterSecret = TlsRsaUtilities.GenerateEncryptedPreMasterSecret(
- context, this.rsaServerPublicKey, output);
+ this.mPremasterSecret = TlsRsaUtilities.GenerateEncryptedPreMasterSecret(context,
+ this.mRsaServerPublicKey, output);
}
}
- public virtual byte[] GeneratePremasterSecret()
+ public override byte[] GeneratePremasterSecret()
{
- byte[] psk = pskIdentity.GetPsk();
+ byte[] psk = mPskIdentity.GetPsk();
byte[] other_secret = GenerateOtherSecret(psk.Length);
MemoryStream buf = new MemoryStream(4 + other_secret.Length + psk.Length);
@@ -181,14 +216,25 @@ namespace Org.BouncyCastle.Crypto.Tls
protected virtual byte[] GenerateOtherSecret(int pskLength)
{
- if (this.keyExchange == KeyExchangeAlgorithm.DHE_PSK)
+ if (this.mKeyExchange == KeyExchangeAlgorithm.DHE_PSK)
+ {
+ if (mDHAgreePrivateKey != null)
+ {
+ return TlsDHUtilities.CalculateDHBasicAgreement(mDHAgreePublicKey, mDHAgreePrivateKey);
+ }
+
+ throw new TlsFatalAlert(AlertDescription.internal_error);
+ }
+
+ if (this.mKeyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
{
- return TlsDHUtilities.CalculateDHBasicAgreement(dhAgreeServerPublicKey, dhAgreeClientPrivateKey);
+ // TODO[RFC 5489]
+ throw new TlsFatalAlert(AlertDescription.internal_error);
}
- if (this.keyExchange == KeyExchangeAlgorithm.RSA_PSK)
+ if (this.mKeyExchange == KeyExchangeAlgorithm.RSA_PSK)
{
- return this.premasterSecret;
+ return this.mPremasterSecret;
}
return new byte[pskLength];
@@ -197,12 +243,10 @@ namespace Org.BouncyCastle.Crypto.Tls
protected virtual RsaKeyParameters ValidateRsaPublicKey(RsaKeyParameters key)
{
// TODO What is the minimum bit length required?
- // key.Modulus.BitLength;
+ // key.Modulus.BitLength;
if (!key.Exponent.IsProbablePrime(2))
- {
throw new TlsFatalAlert(AlertDescription.illegal_parameter);
- }
return key;
}
|