From 057b9516f6e3d3426f8b2175ac29f99d14166ac9 Mon Sep 17 00:00:00 2001 From: Peter Dettman Date: Tue, 31 Aug 2021 19:18:14 +0700 Subject: Test client for TLS 1.3 (external) PSK --- crypto/crypto.csproj | 10 +++ crypto/test/UnitTests.csproj | 2 + crypto/test/src/tls/test/MockPskTls13Client.cs | 110 +++++++++++++++++++++++++ crypto/test/src/tls/test/PskTls13ClientTest.cs | 85 +++++++++++++++++++ 4 files changed, 207 insertions(+) create mode 100644 crypto/test/src/tls/test/MockPskTls13Client.cs create mode 100644 crypto/test/src/tls/test/PskTls13ClientTest.cs diff --git a/crypto/crypto.csproj b/crypto/crypto.csproj index bb2eed13f..442fda26d 100644 --- a/crypto/crypto.csproj +++ b/crypto/crypto.csproj @@ -14583,6 +14583,11 @@ SubType = "Code" BuildAction = "Compile" /> + + + @@ -491,6 +492,7 @@ + diff --git a/crypto/test/src/tls/test/MockPskTls13Client.cs b/crypto/test/src/tls/test/MockPskTls13Client.cs new file mode 100644 index 000000000..d8be1fddd --- /dev/null +++ b/crypto/test/src/tls/test/MockPskTls13Client.cs @@ -0,0 +1,110 @@ +using System; +using System.Collections; +using System.IO; + +using Org.BouncyCastle.Tls.Crypto; +using Org.BouncyCastle.Tls.Crypto.Impl.BC; +using Org.BouncyCastle.Security; +using Org.BouncyCastle.Utilities; + +namespace Org.BouncyCastle.Tls.Tests +{ + internal class MockPskTls13Client + : AbstractTlsClient + { + internal MockPskTls13Client() + : base(new BcTlsCrypto(new SecureRandom())) + { + } + + //public override IList GetEarlyKeyShareGroups() + //{ + // return TlsUtilities.VectorOfOne(NamedGroup.secp256r1); + // //return null; + //} + + //public override short[] GetPskKeyExchangeModes() + //{ + // return new short[] { PskKeyExchangeMode.psk_dhe_ke, PskKeyExchangeMode.psk_ke }; + //} + + protected override IList GetProtocolNames() + { + IList protocolNames = new ArrayList(); + protocolNames.Add(ProtocolName.Http_1_1); + protocolNames.Add(ProtocolName.Http_2_Tls); + return protocolNames; + } + + protected override int[] GetSupportedCipherSuites() + { + return TlsUtilities.GetSupportedCipherSuites(Crypto, new int[] { CipherSuite.TLS_AES_128_GCM_SHA256 }); + } + + protected override ProtocolVersion[] GetSupportedVersions() + { + return ProtocolVersion.TLSv13.Only(); + } + + public override IList GetExternalPsks() + { + byte[] identity = Strings.ToUtf8ByteArray("client"); + TlsSecret key = Crypto.CreateSecret(Strings.ToUtf8ByteArray("TLS_TEST_PSK")); + int prfAlgorithm = PrfAlgorithm.tls13_hkdf_sha256; + + return TlsUtilities.VectorOfOne(new BasicTlsPskExternal(identity, key, prfAlgorithm)); + } + + public override void NotifyAlertRaised(short alertLevel, short alertDescription, string message, + Exception cause) + { + TextWriter output = (alertLevel == AlertLevel.fatal) ? Console.Error : Console.Out; + output.WriteLine("TLS 1.3 PSK client raised alert: " + AlertLevel.GetText(alertLevel) + + ", " + AlertDescription.GetText(alertDescription)); + if (message != null) + { + output.WriteLine("> " + message); + } + if (cause != null) + { + output.WriteLine(cause); + } + } + + public override void NotifyAlertReceived(short alertLevel, short alertDescription) + { + TextWriter output = (alertLevel == AlertLevel.fatal) ? Console.Error : Console.Out; + output.WriteLine("TLS 1.3 PSK client received alert: " + AlertLevel.GetText(alertLevel) + + ", " + AlertDescription.GetText(alertDescription)); + } + + public override void NotifySelectedPsk(TlsPsk selectedPsk) + { + if (null == selectedPsk) + throw new TlsFatalAlert(AlertDescription.handshake_failure); + } + + public override void NotifyServerVersion(ProtocolVersion serverVersion) + { + base.NotifyServerVersion(serverVersion); + + Console.WriteLine("TLS 1.3 PSK client negotiated " + serverVersion); + } + + public override TlsAuthentication GetAuthentication() + { + throw new TlsFatalAlert(AlertDescription.internal_error); + } + + public override void NotifyHandshakeComplete() + { + base.NotifyHandshakeComplete(); + + ProtocolName protocolName = m_context.SecurityParameters.ApplicationProtocol; + if (protocolName != null) + { + Console.WriteLine("Client ALPN: " + protocolName.GetUtf8Decoding()); + } + } + } +} diff --git a/crypto/test/src/tls/test/PskTls13ClientTest.cs b/crypto/test/src/tls/test/PskTls13ClientTest.cs new file mode 100644 index 000000000..6f67b0572 --- /dev/null +++ b/crypto/test/src/tls/test/PskTls13ClientTest.cs @@ -0,0 +1,85 @@ +using System; +using System.IO; +using System.Net.Sockets; +using System.Text; + +using NUnit.Framework; + +using Org.BouncyCastle.Utilities.Date; + +namespace Org.BouncyCastle.Tls.Tests +{ + [TestFixture] + public class PskTls13ClientTest + { + [Test, Ignore] + public void TestConnection() + { + string host = "localhost"; + int port = 5556; + + long time0 = DateTimeUtilities.CurrentUnixMs(); + + MockPskTls13Client client = new MockPskTls13Client(); + TlsClientProtocol protocol = OpenTlsClientConnection(host, port, client); + + long time1 = DateTimeUtilities.CurrentUnixMs(); + Console.WriteLine("Elapsed: " + (time1 - time0) + "ms"); + + Http11Get(host, port, protocol.Stream); + + protocol.Close(); + } + + private static void Http11Get(string host, int port, Stream s) + { + WriteUtf8Line(s, "GET / HTTP/1.1"); + //WriteUtf8Line(s, "Host: " + host + ":" + port); + WriteUtf8Line(s, ""); + s.Flush(); + + Console.WriteLine("---"); + + string[] ends = new string[] { "", "HTTP/1.1 3", "HTTP/1.1 4" }; + + StreamReader reader = new StreamReader(s); + + bool finished = false; + string line; + while (!finished && (line = reader.ReadLine()) != null) + { + Console.WriteLine("<<< " + line); + + string upperLine = TlsTestUtilities.ToUpperInvariant(line); + + // TEST CODE ONLY. This is not a robust way of parsing the result! + foreach (string end in ends) + { + if (upperLine.IndexOf(end) >= 0) + { + finished = true; + break; + } + } + } + + Console.Out.Flush(); + } + + private static TlsClientProtocol OpenTlsClientConnection(string hostname, int port, TlsClient client) + { + TcpClient tcp = new TcpClient(hostname, port); + + TlsClientProtocol protocol = new TlsClientProtocol(tcp.GetStream()); + protocol.Connect(client); + return protocol; + } + + private static void WriteUtf8Line(Stream output, string line) + { + byte[] buf = Encoding.UTF8.GetBytes(line + "\r\n"); + output.Write(buf, 0, buf.Length); + Console.WriteLine(">>> " + line); + } + } +} -- cgit 1.4.1