summary refs log tree commit diff
path: root/crypto/test/src
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/test/src')
-rw-r--r--crypto/test/src/tls/test/MockPskTls13Client.cs110
-rw-r--r--crypto/test/src/tls/test/PskTls13ClientTest.cs85
2 files changed, 195 insertions, 0 deletions
diff --git a/crypto/test/src/tls/test/MockPskTls13Client.cs b/crypto/test/src/tls/test/MockPskTls13Client.cs
new file mode 100644
index 000000000..d8be1fddd
--- /dev/null
+++ b/crypto/test/src/tls/test/MockPskTls13Client.cs
@@ -0,0 +1,110 @@
+using System;
+using System.Collections;
+using System.IO;
+
+using Org.BouncyCastle.Tls.Crypto;
+using Org.BouncyCastle.Tls.Crypto.Impl.BC;
+using Org.BouncyCastle.Security;
+using Org.BouncyCastle.Utilities;
+
+namespace Org.BouncyCastle.Tls.Tests
+{
+    internal class MockPskTls13Client
+        : AbstractTlsClient
+    {
+        internal MockPskTls13Client()
+            : base(new BcTlsCrypto(new SecureRandom()))
+        {
+        }
+
+        //public override IList GetEarlyKeyShareGroups()
+        //{
+        //    return TlsUtilities.VectorOfOne(NamedGroup.secp256r1);
+        //    //return null;
+        //}
+
+        //public override short[] GetPskKeyExchangeModes()
+        //{
+        //    return new short[] { PskKeyExchangeMode.psk_dhe_ke, PskKeyExchangeMode.psk_ke };
+        //}
+
+        protected override IList GetProtocolNames()
+        {
+            IList protocolNames = new ArrayList();
+            protocolNames.Add(ProtocolName.Http_1_1);
+            protocolNames.Add(ProtocolName.Http_2_Tls);
+            return protocolNames;
+        }
+
+        protected override int[] GetSupportedCipherSuites()
+        {
+            return TlsUtilities.GetSupportedCipherSuites(Crypto, new int[] { CipherSuite.TLS_AES_128_GCM_SHA256 });
+        }
+
+        protected override ProtocolVersion[] GetSupportedVersions()
+        {
+            return ProtocolVersion.TLSv13.Only();
+        }
+
+        public override IList GetExternalPsks()
+        {
+            byte[] identity = Strings.ToUtf8ByteArray("client");
+            TlsSecret key = Crypto.CreateSecret(Strings.ToUtf8ByteArray("TLS_TEST_PSK"));
+            int prfAlgorithm = PrfAlgorithm.tls13_hkdf_sha256;
+
+            return TlsUtilities.VectorOfOne(new BasicTlsPskExternal(identity, key, prfAlgorithm));
+        }
+
+        public override void NotifyAlertRaised(short alertLevel, short alertDescription, string message,
+            Exception cause)
+        {
+            TextWriter output = (alertLevel == AlertLevel.fatal) ? Console.Error : Console.Out;
+            output.WriteLine("TLS 1.3 PSK client raised alert: " + AlertLevel.GetText(alertLevel)
+                + ", " + AlertDescription.GetText(alertDescription));
+            if (message != null)
+            {
+                output.WriteLine("> " + message);
+            }
+            if (cause != null)
+            {
+                output.WriteLine(cause);
+            }
+        }
+
+        public override void NotifyAlertReceived(short alertLevel, short alertDescription)
+        {
+            TextWriter output = (alertLevel == AlertLevel.fatal) ? Console.Error : Console.Out;
+            output.WriteLine("TLS 1.3 PSK client received alert: " + AlertLevel.GetText(alertLevel)
+                + ", " + AlertDescription.GetText(alertDescription));
+        }
+
+        public override void NotifySelectedPsk(TlsPsk selectedPsk)
+        {
+            if (null == selectedPsk)
+                throw new TlsFatalAlert(AlertDescription.handshake_failure);
+        }
+
+        public override void NotifyServerVersion(ProtocolVersion serverVersion)
+        {
+            base.NotifyServerVersion(serverVersion);
+
+            Console.WriteLine("TLS 1.3 PSK client negotiated " + serverVersion);
+        }
+
+        public override TlsAuthentication GetAuthentication()
+        {
+            throw new TlsFatalAlert(AlertDescription.internal_error);
+        }
+
+        public override void NotifyHandshakeComplete()
+        {
+            base.NotifyHandshakeComplete();
+
+            ProtocolName protocolName = m_context.SecurityParameters.ApplicationProtocol;
+            if (protocolName != null)
+            {
+                Console.WriteLine("Client ALPN: " + protocolName.GetUtf8Decoding());
+            }
+        }
+    }
+}
diff --git a/crypto/test/src/tls/test/PskTls13ClientTest.cs b/crypto/test/src/tls/test/PskTls13ClientTest.cs
new file mode 100644
index 000000000..6f67b0572
--- /dev/null
+++ b/crypto/test/src/tls/test/PskTls13ClientTest.cs
@@ -0,0 +1,85 @@
+using System;
+using System.IO;
+using System.Net.Sockets;
+using System.Text;
+
+using NUnit.Framework;
+
+using Org.BouncyCastle.Utilities.Date;
+
+namespace Org.BouncyCastle.Tls.Tests
+{
+    [TestFixture]
+    public class PskTls13ClientTest
+    {
+        [Test, Ignore]
+        public void TestConnection()
+        {
+            string host = "localhost";
+            int port = 5556;
+
+            long time0 = DateTimeUtilities.CurrentUnixMs();
+
+            MockPskTls13Client client = new MockPskTls13Client();
+            TlsClientProtocol protocol = OpenTlsClientConnection(host, port, client);
+
+            long time1 = DateTimeUtilities.CurrentUnixMs();
+            Console.WriteLine("Elapsed: " + (time1 - time0) + "ms");
+
+            Http11Get(host, port, protocol.Stream);
+
+            protocol.Close();
+        }
+
+        private static void Http11Get(string host, int port, Stream s)
+        {
+            WriteUtf8Line(s, "GET / HTTP/1.1");
+            //WriteUtf8Line(s, "Host: " + host + ":" + port);
+            WriteUtf8Line(s, "");
+            s.Flush();
+
+            Console.WriteLine("---");
+
+            string[] ends = new string[] { "</HTML>", "HTTP/1.1 3", "HTTP/1.1 4" };
+
+            StreamReader reader = new StreamReader(s);
+
+            bool finished = false;
+            string line;
+            while (!finished && (line = reader.ReadLine()) != null)
+            {
+                Console.WriteLine("<<< " + line);
+
+                string upperLine = TlsTestUtilities.ToUpperInvariant(line);
+
+                // TEST CODE ONLY. This is not a robust way of parsing the result!
+                foreach (string end in ends)
+                {
+                    if (upperLine.IndexOf(end) >= 0)
+                    {
+                        finished = true;
+                        break;
+                    }
+                }
+            }
+
+            Console.Out.Flush();
+        }
+
+        private static TlsClientProtocol OpenTlsClientConnection(string hostname, int port, TlsClient client)
+        {
+            TcpClient tcp = new TcpClient(hostname, port);
+
+            TlsClientProtocol protocol = new TlsClientProtocol(tcp.GetStream());
+            protocol.Connect(client);
+            return protocol;
+        }
+
+        private static void WriteUtf8Line(Stream output, string line)
+        {
+            byte[] buf = Encoding.UTF8.GetBytes(line + "\r\n");
+            output.Write(buf, 0, buf.Length);
+            Console.WriteLine(">>> " + line);
+        }
+    }
+}