summary refs log tree commit diff
path: root/crypto/src/crypto/tls/DefaultTlsAgreementCredentials.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/crypto/tls/DefaultTlsAgreementCredentials.cs')
-rw-r--r--crypto/src/crypto/tls/DefaultTlsAgreementCredentials.cs104
1 files changed, 53 insertions, 51 deletions
diff --git a/crypto/src/crypto/tls/DefaultTlsAgreementCredentials.cs b/crypto/src/crypto/tls/DefaultTlsAgreementCredentials.cs
index 2dfe526d1..5147a1990 100644
--- a/crypto/src/crypto/tls/DefaultTlsAgreementCredentials.cs
+++ b/crypto/src/crypto/tls/DefaultTlsAgreementCredentials.cs
@@ -1,4 +1,5 @@
 using System;
+using System.IO;
 
 using Org.BouncyCastle.Crypto.Agreement;
 using Org.BouncyCastle.Crypto.Parameters;
@@ -7,61 +8,62 @@ using Org.BouncyCastle.Utilities;
 
 namespace Org.BouncyCastle.Crypto.Tls
 {
-	public class DefaultTlsAgreementCredentials
-		: TlsAgreementCredentials
-	{
-		protected Certificate clientCert;
-		protected AsymmetricKeyParameter clientPrivateKey;
+    public class DefaultTlsAgreementCredentials
+        : AbstractTlsAgreementCredentials
+    {
+        protected readonly Certificate mCertificate;
+        protected readonly AsymmetricKeyParameter mPrivateKey;
 
-		protected IBasicAgreement basicAgreement;
+        protected readonly IBasicAgreement mBasicAgreement;
+        protected readonly bool mTruncateAgreement;
 
-		public DefaultTlsAgreementCredentials(Certificate clientCertificate, AsymmetricKeyParameter clientPrivateKey)
-		{
-			if (clientCertificate == null)
-			{
-				throw new ArgumentNullException("clientCertificate");
-			}
-			if (clientCertificate.certs.Length == 0)
-			{
-				throw new ArgumentException("cannot be empty", "clientCertificate");
-			}
-			if (clientPrivateKey == null)
-			{
-				throw new ArgumentNullException("clientPrivateKey");
-			}
-			if (!clientPrivateKey.IsPrivate)
-			{
-				throw new ArgumentException("must be private", "clientPrivateKey");
-			}
+        public DefaultTlsAgreementCredentials(Certificate certificate, AsymmetricKeyParameter privateKey)
+        {
+            if (certificate == null)
+                throw new ArgumentNullException("certificate");
+            if (certificate.IsEmpty)
+                throw new ArgumentException("cannot be empty", "certificate");
+            if (privateKey == null)
+                throw new ArgumentNullException("privateKey");
+            if (!privateKey.IsPrivate)
+                throw new ArgumentException("must be private", "privateKey");
 
-			if (clientPrivateKey is DHPrivateKeyParameters)
-			{
-				basicAgreement = new DHBasicAgreement();
-			}
-			else if (clientPrivateKey is ECPrivateKeyParameters)
-			{
-				basicAgreement = new ECDHBasicAgreement();
-			}
-			else
-			{
-				throw new ArgumentException("type not supported: "
-					+ clientPrivateKey.GetType().FullName, "clientPrivateKey");
-			}
+            if (privateKey is DHPrivateKeyParameters)
+            {
+                mBasicAgreement = new DHBasicAgreement();
+                mTruncateAgreement = true;
+            }
+            else if (privateKey is ECPrivateKeyParameters)
+            {
+                mBasicAgreement = new ECDHBasicAgreement();
+                mTruncateAgreement = false;
+            }
+            else
+            {
+                throw new ArgumentException("type not supported: " + privateKey.GetType().FullName, "privateKey");
+            }
 
-			this.clientCert = clientCertificate;
-			this.clientPrivateKey = clientPrivateKey;
-		}
+            this.mCertificate = certificate;
+            this.mPrivateKey = privateKey;
+        }
 
-		public virtual Certificate Certificate
-		{
-			get { return clientCert; }
-		}
+        public override Certificate Certificate
+        {
+            get { return mCertificate; }
+        }
 
-		public virtual byte[] GenerateAgreement(AsymmetricKeyParameter serverPublicKey)
-		{
-			basicAgreement.Init(clientPrivateKey);
-			BigInteger agreementValue = basicAgreement.CalculateAgreement(serverPublicKey);
-			return BigIntegers.AsUnsignedByteArray(agreementValue);
-		}
-	}
+        /// <exception cref="IOException"></exception>
+        public override byte[] GenerateAgreement(AsymmetricKeyParameter peerPublicKey)
+        {
+            mBasicAgreement.Init(mPrivateKey);
+            BigInteger agreementValue = mBasicAgreement.CalculateAgreement(peerPublicKey);
+
+            if (mTruncateAgreement)
+            {
+                return BigIntegers.AsUnsignedByteArray(agreementValue);
+            }
+
+            return BigIntegers.AsUnsignedByteArray(mBasicAgreement.GetFieldSize(), agreementValue);
+        }
+    }
 }