diff --git a/crypto/test/UnitTests.csproj b/crypto/test/UnitTests.csproj
index c5fdecd54..72d9e6320 100644
--- a/crypto/test/UnitTests.csproj
+++ b/crypto/test/UnitTests.csproj
@@ -482,6 +482,7 @@
<Compile Include="src\tls\test\MockDtlsServer.cs" />
<Compile Include="src\tls\test\MockPskDtlsClient.cs" />
<Compile Include="src\tls\test\MockPskDtlsServer.cs" />
+ <Compile Include="src\tls\test\MockPskTls13Client.cs" />
<Compile Include="src\tls\test\MockPskTlsClient.cs" />
<Compile Include="src\tls\test\MockPskTlsServer.cs" />
<Compile Include="src\tls\test\MockSrpTlsClient.cs" />
@@ -491,6 +492,7 @@
<Compile Include="src\tls\test\NetworkStream.cs" />
<Compile Include="src\tls\test\PipedStream.cs" />
<Compile Include="src\tls\test\PrfTest.cs" />
+ <Compile Include="src\tls\test\PskTls13ClientTest.cs" />
<Compile Include="src\tls\test\PskTlsClientTest.cs" />
<Compile Include="src\tls\test\PskTlsServerTest.cs" />
<Compile Include="src\tls\test\TlsClientTest.cs" />
diff --git a/crypto/test/src/tls/test/MockPskTls13Client.cs b/crypto/test/src/tls/test/MockPskTls13Client.cs
new file mode 100644
index 000000000..d8be1fddd
--- /dev/null
+++ b/crypto/test/src/tls/test/MockPskTls13Client.cs
@@ -0,0 +1,110 @@
+using System;
+using System.Collections;
+using System.IO;
+
+using Org.BouncyCastle.Tls.Crypto;
+using Org.BouncyCastle.Tls.Crypto.Impl.BC;
+using Org.BouncyCastle.Security;
+using Org.BouncyCastle.Utilities;
+
+namespace Org.BouncyCastle.Tls.Tests
+{
+ internal class MockPskTls13Client
+ : AbstractTlsClient
+ {
+ internal MockPskTls13Client()
+ : base(new BcTlsCrypto(new SecureRandom()))
+ {
+ }
+
+ //public override IList GetEarlyKeyShareGroups()
+ //{
+ // return TlsUtilities.VectorOfOne(NamedGroup.secp256r1);
+ // //return null;
+ //}
+
+ //public override short[] GetPskKeyExchangeModes()
+ //{
+ // return new short[] { PskKeyExchangeMode.psk_dhe_ke, PskKeyExchangeMode.psk_ke };
+ //}
+
+ protected override IList GetProtocolNames()
+ {
+ IList protocolNames = new ArrayList();
+ protocolNames.Add(ProtocolName.Http_1_1);
+ protocolNames.Add(ProtocolName.Http_2_Tls);
+ return protocolNames;
+ }
+
+ protected override int[] GetSupportedCipherSuites()
+ {
+ return TlsUtilities.GetSupportedCipherSuites(Crypto, new int[] { CipherSuite.TLS_AES_128_GCM_SHA256 });
+ }
+
+ protected override ProtocolVersion[] GetSupportedVersions()
+ {
+ return ProtocolVersion.TLSv13.Only();
+ }
+
+ public override IList GetExternalPsks()
+ {
+ byte[] identity = Strings.ToUtf8ByteArray("client");
+ TlsSecret key = Crypto.CreateSecret(Strings.ToUtf8ByteArray("TLS_TEST_PSK"));
+ int prfAlgorithm = PrfAlgorithm.tls13_hkdf_sha256;
+
+ return TlsUtilities.VectorOfOne(new BasicTlsPskExternal(identity, key, prfAlgorithm));
+ }
+
+ public override void NotifyAlertRaised(short alertLevel, short alertDescription, string message,
+ Exception cause)
+ {
+ TextWriter output = (alertLevel == AlertLevel.fatal) ? Console.Error : Console.Out;
+ output.WriteLine("TLS 1.3 PSK client raised alert: " + AlertLevel.GetText(alertLevel)
+ + ", " + AlertDescription.GetText(alertDescription));
+ if (message != null)
+ {
+ output.WriteLine("> " + message);
+ }
+ if (cause != null)
+ {
+ output.WriteLine(cause);
+ }
+ }
+
+ public override void NotifyAlertReceived(short alertLevel, short alertDescription)
+ {
+ TextWriter output = (alertLevel == AlertLevel.fatal) ? Console.Error : Console.Out;
+ output.WriteLine("TLS 1.3 PSK client received alert: " + AlertLevel.GetText(alertLevel)
+ + ", " + AlertDescription.GetText(alertDescription));
+ }
+
+ public override void NotifySelectedPsk(TlsPsk selectedPsk)
+ {
+ if (null == selectedPsk)
+ throw new TlsFatalAlert(AlertDescription.handshake_failure);
+ }
+
+ public override void NotifyServerVersion(ProtocolVersion serverVersion)
+ {
+ base.NotifyServerVersion(serverVersion);
+
+ Console.WriteLine("TLS 1.3 PSK client negotiated " + serverVersion);
+ }
+
+ public override TlsAuthentication GetAuthentication()
+ {
+ throw new TlsFatalAlert(AlertDescription.internal_error);
+ }
+
+ public override void NotifyHandshakeComplete()
+ {
+ base.NotifyHandshakeComplete();
+
+ ProtocolName protocolName = m_context.SecurityParameters.ApplicationProtocol;
+ if (protocolName != null)
+ {
+ Console.WriteLine("Client ALPN: " + protocolName.GetUtf8Decoding());
+ }
+ }
+ }
+}
diff --git a/crypto/test/src/tls/test/PskTls13ClientTest.cs b/crypto/test/src/tls/test/PskTls13ClientTest.cs
new file mode 100644
index 000000000..6f67b0572
--- /dev/null
+++ b/crypto/test/src/tls/test/PskTls13ClientTest.cs
@@ -0,0 +1,85 @@
+using System;
+using System.IO;
+using System.Net.Sockets;
+using System.Text;
+
+using NUnit.Framework;
+
+using Org.BouncyCastle.Utilities.Date;
+
+namespace Org.BouncyCastle.Tls.Tests
+{
+ [TestFixture]
+ public class PskTls13ClientTest
+ {
+ [Test, Ignore]
+ public void TestConnection()
+ {
+ string host = "localhost";
+ int port = 5556;
+
+ long time0 = DateTimeUtilities.CurrentUnixMs();
+
+ MockPskTls13Client client = new MockPskTls13Client();
+ TlsClientProtocol protocol = OpenTlsClientConnection(host, port, client);
+
+ long time1 = DateTimeUtilities.CurrentUnixMs();
+ Console.WriteLine("Elapsed: " + (time1 - time0) + "ms");
+
+ Http11Get(host, port, protocol.Stream);
+
+ protocol.Close();
+ }
+
+ private static void Http11Get(string host, int port, Stream s)
+ {
+ WriteUtf8Line(s, "GET / HTTP/1.1");
+ //WriteUtf8Line(s, "Host: " + host + ":" + port);
+ WriteUtf8Line(s, "");
+ s.Flush();
+
+ Console.WriteLine("---");
+
+ string[] ends = new string[] { "</HTML>", "HTTP/1.1 3", "HTTP/1.1 4" };
+
+ StreamReader reader = new StreamReader(s);
+
+ bool finished = false;
+ string line;
+ while (!finished && (line = reader.ReadLine()) != null)
+ {
+ Console.WriteLine("<<< " + line);
+
+ string upperLine = TlsTestUtilities.ToUpperInvariant(line);
+
+ // TEST CODE ONLY. This is not a robust way of parsing the result!
+ foreach (string end in ends)
+ {
+ if (upperLine.IndexOf(end) >= 0)
+ {
+ finished = true;
+ break;
+ }
+ }
+ }
+
+ Console.Out.Flush();
+ }
+
+ private static TlsClientProtocol OpenTlsClientConnection(string hostname, int port, TlsClient client)
+ {
+ TcpClient tcp = new TcpClient(hostname, port);
+
+ TlsClientProtocol protocol = new TlsClientProtocol(tcp.GetStream());
+ protocol.Connect(client);
+ return protocol;
+ }
+
+ private static void WriteUtf8Line(Stream output, string line)
+ {
+ byte[] buf = Encoding.UTF8.GetBytes(line + "\r\n");
+ output.Write(buf, 0, buf.Length);
+ Console.WriteLine(">>> " + line);
+ }
+ }
+}
|