summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crypto/crypto.csproj15
-rw-r--r--crypto/test/src/crypto/tls/test/MockPskTlsClient.cs132
-rw-r--r--crypto/test/src/crypto/tls/test/MockPskTlsServer.cs105
-rw-r--r--crypto/test/src/crypto/tls/test/TlsPskProtocolTest.cs80
4 files changed, 332 insertions, 0 deletions
diff --git a/crypto/crypto.csproj b/crypto/crypto.csproj
index 66a50bb39..84844e827 100644
--- a/crypto/crypto.csproj
+++ b/crypto/crypto.csproj
@@ -11033,6 +11033,16 @@
                     BuildAction = "Compile"
                 />
                 <File
+                    RelPath = "test\src\crypto\tls\test\MockPskTlsClient.cs"
+                    SubType = "Code"
+                    BuildAction = "Compile"
+                />
+                <File
+                    RelPath = "test\src\crypto\tls\test\MockPskTlsServer.cs"
+                    SubType = "Code"
+                    BuildAction = "Compile"
+                />
+                <File
                     RelPath = "test\src\crypto\tls\test\MockTlsClient.cs"
                     SubType = "Code"
                     BuildAction = "Compile"
@@ -11058,6 +11068,11 @@
                     BuildAction = "Compile"
                 />
                 <File
+                    RelPath = "test\src\crypto\tls\test\TlsPskProtocolTest.cs"
+                    SubType = "Code"
+                    BuildAction = "Compile"
+                />
+                <File
                     RelPath = "test\src\crypto\tls\test\TlsServerTest.cs"
                     SubType = "Code"
                     BuildAction = "Compile"
diff --git a/crypto/test/src/crypto/tls/test/MockPskTlsClient.cs b/crypto/test/src/crypto/tls/test/MockPskTlsClient.cs
new file mode 100644
index 000000000..4e183cba1
--- /dev/null
+++ b/crypto/test/src/crypto/tls/test/MockPskTlsClient.cs
@@ -0,0 +1,132 @@
+using System;
+using System.Collections;
+using System.IO;
+
+using Org.BouncyCastle.Asn1.X509;
+using Org.BouncyCastle.Utilities;
+using Org.BouncyCastle.Utilities.Encoders;
+
+namespace Org.BouncyCastle.Crypto.Tls.Tests
+{
+    internal class MockPskTlsClient
+        :   PskTlsClient
+    {
+        internal TlsSession mSession;
+
+        internal MockPskTlsClient(TlsSession session)
+            :   this(session, new BasicTlsPskIdentity("client", new byte[16]))
+        {
+        }
+
+        internal MockPskTlsClient(TlsSession session, TlsPskIdentity pskIdentity)
+            :   base(pskIdentity)
+        {
+            this.mSession = session;
+        }
+
+        public override TlsSession GetSessionToResume()
+        {
+            return this.mSession;
+        }
+
+        public override void NotifyAlertRaised(byte alertLevel, byte alertDescription, string message, Exception cause)
+        {
+            TextWriter output = (alertLevel == AlertLevel.fatal) ? Console.Error : Console.Out;
+            output.WriteLine("TLS-PSK client raised alert: " + AlertLevel.GetText(alertLevel)
+                + ", " + AlertDescription.GetText(alertDescription));
+            if (message != null)
+            {
+                output.WriteLine("> " + message);
+            }
+            if (cause != null)
+            {
+                output.WriteLine(cause);
+            }
+        }
+
+        public override void NotifyAlertReceived(byte alertLevel, byte alertDescription)
+        {
+            TextWriter output = (alertLevel == AlertLevel.fatal) ? Console.Error : Console.Out;
+            output.WriteLine("TLS-PSK client received alert: " + AlertLevel.GetText(alertLevel)
+                + ", " + AlertDescription.GetText(alertDescription));
+        }
+
+        public override void NotifyHandshakeComplete()
+        {
+            base.NotifyHandshakeComplete();
+
+            TlsSession newSession = mContext.ResumableSession;
+            if (newSession != null)
+            {
+                byte[] newSessionID = newSession.SessionID;
+                string hex = Hex.ToHexString(newSessionID);
+
+                if (this.mSession != null && Arrays.AreEqual(this.mSession.SessionID, newSessionID))
+                {
+                    Console.WriteLine("Resumed session: " + hex);
+                }
+                else
+                {
+                    Console.WriteLine("Established session: " + hex);
+                }
+
+                this.mSession = newSession;
+            }
+        }
+
+        public override int[] GetCipherSuites()
+        {
+            return new int[]{ CipherSuite.TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384,
+                CipherSuite.TLS_DHE_PSK_WITH_AES_256_CBC_SHA384, CipherSuite.TLS_RSA_PSK_WITH_AES_256_CBC_SHA384,
+                CipherSuite.TLS_PSK_WITH_AES_256_CBC_SHA };
+        }
+
+        public override ProtocolVersion MinimumVersion
+        {
+	        get { return ProtocolVersion.TLSv12; }
+        }
+
+        public override IDictionary GetClientExtensions()
+        {
+            IDictionary clientExtensions = TlsExtensionsUtilities.EnsureExtensionsInitialised(base.GetClientExtensions());
+            TlsExtensionsUtilities.AddEncryptThenMacExtension(clientExtensions);
+            return clientExtensions;
+        }
+
+        public override void NotifyServerVersion(ProtocolVersion serverVersion)
+        {
+            base.NotifyServerVersion(serverVersion);
+
+            Console.WriteLine("TLS-PSK client negotiated " + serverVersion);
+        }
+
+        public override TlsAuthentication GetAuthentication()
+        {
+            return new MyTlsAuthentication(mContext);
+        }
+
+        internal class MyTlsAuthentication
+            :   ServerOnlyTlsAuthentication
+        {
+            private readonly TlsContext mContext;
+
+            internal MyTlsAuthentication(TlsContext context)
+            {
+                this.mContext = context;
+            }
+
+            public override void NotifyServerCertificate(Certificate serverCertificate)
+            {
+                X509CertificateStructure[] chain = serverCertificate.GetCertificateList();
+                Console.WriteLine("TLS-PSK client received server certificate chain of length " + chain.Length);
+                for (int i = 0; i != chain.Length; i++)
+                {
+                    X509CertificateStructure entry = chain[i];
+                    // TODO Create Fingerprint based on certificate signature algorithm digest
+                    Console.WriteLine("    Fingerprint:SHA-256 " + TlsTestUtilities.Fingerprint(entry) + " ("
+                        + entry.Subject + ")");
+                }
+            }
+        };
+    }
+}
diff --git a/crypto/test/src/crypto/tls/test/MockPskTlsServer.cs b/crypto/test/src/crypto/tls/test/MockPskTlsServer.cs
new file mode 100644
index 000000000..7394a2077
--- /dev/null
+++ b/crypto/test/src/crypto/tls/test/MockPskTlsServer.cs
@@ -0,0 +1,105 @@
+using System;
+using System.Collections;
+using System.IO;
+
+using Org.BouncyCastle.Utilities;
+
+namespace Org.BouncyCastle.Crypto.Tls.Tests
+{
+    internal class MockPskTlsServer
+        :   PskTlsServer
+    {
+        internal MockPskTlsServer()
+            :   base(new MyIdentityManager())
+        {
+        }
+
+        public override void NotifyAlertRaised(byte alertLevel, byte alertDescription, string message, Exception cause)
+        {
+            TextWriter output = (alertLevel == AlertLevel.fatal) ? Console.Error : Console.Out;
+            output.WriteLine("TLS-PSK server raised alert: " + AlertLevel.GetText(alertLevel)
+                + ", " + AlertDescription.GetText(alertDescription));
+            if (message != null)
+            {
+                output.WriteLine("> " + message);
+            }
+            if (cause != null)
+            {
+                output.WriteLine(cause);
+            }
+        }
+
+        public override void NotifyAlertReceived(byte alertLevel, byte alertDescription)
+        {
+            TextWriter output = (alertLevel == AlertLevel.fatal) ? Console.Error : Console.Out;
+            output.WriteLine("TLS-PSK server received alert: " + AlertLevel.GetText(alertLevel)
+                + ", " + AlertDescription.GetText(alertDescription));
+        }
+
+        public override void NotifyHandshakeComplete()
+        {
+            base.NotifyHandshakeComplete();
+
+            byte[] pskIdentity = mContext.SecurityParameters.PskIdentity;
+            if (pskIdentity != null)
+            {
+                string name = Strings.FromUtf8ByteArray(pskIdentity);
+                Console.WriteLine("TLS-PSK server completed handshake for PSK identity: " + name);
+            }
+        }
+
+        protected override int[] GetCipherSuites()
+        {
+            return new int[]{ CipherSuite.TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384,
+                CipherSuite.TLS_DHE_PSK_WITH_AES_256_CBC_SHA384, CipherSuite.TLS_RSA_PSK_WITH_AES_256_CBC_SHA384,
+                CipherSuite.TLS_PSK_WITH_AES_256_CBC_SHA };
+        }
+
+        protected override ProtocolVersion MaximumVersion
+        {
+            get { return ProtocolVersion.TLSv12; }
+        }
+
+        protected override ProtocolVersion MinimumVersion
+        {
+            get { return ProtocolVersion.TLSv12; }
+        }
+
+        public override ProtocolVersion GetServerVersion()
+        {
+            ProtocolVersion serverVersion = base.GetServerVersion();
+
+            Console.WriteLine("TLS-PSK server negotiated " + serverVersion);
+
+            return serverVersion;
+        }
+
+        protected override TlsEncryptionCredentials GetRsaEncryptionCredentials()
+        {
+            return TlsTestUtilities.LoadEncryptionCredentials(mContext, new string[]{"x509-server.pem", "x509-ca.pem"},
+                "x509-server-key.pem");
+        }
+
+        internal class MyIdentityManager
+            :   TlsPskIdentityManager
+        {
+            public virtual byte[] GetHint()
+            {
+                return Strings.ToUtf8ByteArray("hint");
+            }
+
+            public virtual byte[] GetPsk(byte[] identity)
+            {
+                if (identity != null)
+                {
+                    string name = Strings.FromUtf8ByteArray(identity);
+                    if (name.Equals("client"))
+                    {
+                        return new byte[16];
+                    }
+                }
+                return null;
+            }
+        }
+    }
+}
diff --git a/crypto/test/src/crypto/tls/test/TlsPskProtocolTest.cs b/crypto/test/src/crypto/tls/test/TlsPskProtocolTest.cs
new file mode 100644
index 000000000..b059bb2cb
--- /dev/null
+++ b/crypto/test/src/crypto/tls/test/TlsPskProtocolTest.cs
@@ -0,0 +1,80 @@
+using System;
+using System.IO;
+using System.Threading;
+
+using Org.BouncyCastle.Security;
+using Org.BouncyCastle.Utilities;
+using Org.BouncyCastle.Utilities.IO;
+
+using NUnit.Framework;
+
+namespace Org.BouncyCastle.Crypto.Tls.Tests
+{
+    [TestFixture]
+    public class TlsPskProtocolTest
+    {
+        [Test]
+        public void TestClientServer()
+        {
+            SecureRandom secureRandom = new SecureRandom();
+
+            PipedStream clientPipe = new PipedStream();
+            PipedStream serverPipe = new PipedStream(clientPipe);
+
+            TlsClientProtocol clientProtocol = new TlsClientProtocol(clientPipe, secureRandom);
+            TlsServerProtocol serverProtocol = new TlsServerProtocol(serverPipe, secureRandom);
+
+            Server server = new Server(serverProtocol);
+
+            Thread serverThread = new Thread(new ThreadStart(server.Run));
+            serverThread.Start();
+
+            MockPskTlsClient client = new MockPskTlsClient(null);
+            clientProtocol.Connect(client);
+
+            // NOTE: Because we write-all before we read-any, this length can't be more than the pipe capacity
+            int length = 1000;
+
+            byte[] data = new byte[length];
+            secureRandom.NextBytes(data);
+
+            Stream output = clientProtocol.Stream;
+            output.Write(data, 0, data.Length);
+
+            byte[] echo = new byte[data.Length];
+            int count = Streams.ReadFully(clientProtocol.Stream, echo);
+
+            Assert.AreEqual(count, data.Length);
+            Assert.IsTrue(Arrays.AreEqual(data, echo));
+
+            output.Close();
+
+            serverThread.Join();
+        }
+
+        internal class Server
+        {
+            private readonly TlsServerProtocol mServerProtocol;
+
+            internal Server(TlsServerProtocol serverProtocol)
+            {
+                this.mServerProtocol = serverProtocol;
+            }
+
+            public void Run()
+            {
+                try
+                {
+                    MockPskTlsServer server = new MockPskTlsServer();
+                    mServerProtocol.Accept(server);
+                    Streams.PipeAll(mServerProtocol.Stream, mServerProtocol.Stream);
+                    mServerProtocol.Close();
+                }
+                catch (Exception)
+                {
+                    //throw new RuntimeException(e);
+                }
+            }
+        }
+    }
+}