diff --git a/crypto/test/UnitTests.csproj b/crypto/test/UnitTests.csproj
index 1650a05fa..1945f1367 100644
--- a/crypto/test/UnitTests.csproj
+++ b/crypto/test/UnitTests.csproj
@@ -491,6 +491,7 @@
<Compile Include="src\tls\test\MockPskDtlsClient.cs" />
<Compile Include="src\tls\test\MockPskDtlsServer.cs" />
<Compile Include="src\tls\test\MockPskTls13Client.cs" />
+ <Compile Include="src\tls\test\MockPskTls13Server.cs" />
<Compile Include="src\tls\test\MockPskTlsClient.cs" />
<Compile Include="src\tls\test\MockPskTlsServer.cs" />
<Compile Include="src\tls\test\MockSrpTlsClient.cs" />
@@ -501,8 +502,10 @@
<Compile Include="src\tls\test\PipedStream.cs" />
<Compile Include="src\tls\test\PrfTest.cs" />
<Compile Include="src\tls\test\PskTls13ClientTest.cs" />
+ <Compile Include="src\tls\test\PskTls13ServerTest.cs" />
<Compile Include="src\tls\test\PskTlsClientTest.cs" />
<Compile Include="src\tls\test\PskTlsServerTest.cs" />
+ <Compile Include="src\tls\test\Tls13PskProtocolTest.cs" />
<Compile Include="src\tls\test\TlsClientTest.cs" />
<Compile Include="src\tls\test\TlsProtocolNonBlockingTest.cs" />
<Compile Include="src\tls\test\TlsProtocolTest.cs" />
diff --git a/crypto/test/src/tls/test/MockPskTls13Server.cs b/crypto/test/src/tls/test/MockPskTls13Server.cs
new file mode 100644
index 000000000..d1ea69b95
--- /dev/null
+++ b/crypto/test/src/tls/test/MockPskTls13Server.cs
@@ -0,0 +1,108 @@
+using System;
+using System.Collections;
+using System.IO;
+
+using Org.BouncyCastle.Tls.Crypto;
+using Org.BouncyCastle.Tls.Crypto.Impl.BC;
+using Org.BouncyCastle.Security;
+using Org.BouncyCastle.Utilities;
+
+namespace Org.BouncyCastle.Tls.Tests
+{
+ internal class MockPskTls13Server
+ : AbstractTlsServer
+ {
+ internal MockPskTls13Server()
+ : base(new BcTlsCrypto(new SecureRandom()))
+ {
+ }
+
+ public override TlsCredentials GetCredentials()
+ {
+ return null;
+ }
+
+ protected override IList GetProtocolNames()
+ {
+ IList protocolNames = new ArrayList();
+ protocolNames.Add(ProtocolName.Http_2_Tls);
+ protocolNames.Add(ProtocolName.Http_1_1);
+ return protocolNames;
+ }
+
+ protected override int[] GetSupportedCipherSuites()
+ {
+ return TlsUtilities.GetSupportedCipherSuites(Crypto,
+ new int[] { CipherSuite.TLS_AES_128_CCM_8_SHA256, CipherSuite.TLS_AES_128_CCM_SHA256,
+ CipherSuite.TLS_AES_128_GCM_SHA256, CipherSuite.TLS_CHACHA20_POLY1305_SHA256 });
+ }
+
+ protected override ProtocolVersion[] GetSupportedVersions()
+ {
+ return ProtocolVersion.TLSv13.Only();
+ }
+
+ public override ProtocolVersion GetServerVersion()
+ {
+ ProtocolVersion serverVersion = base.GetServerVersion();
+
+ Console.WriteLine("TLS 1.3 PSK server negotiated " + serverVersion);
+
+ return serverVersion;
+ }
+
+ public override TlsPskExternal GetExternalPsk(IList identities)
+ {
+ byte[] identity = Strings.ToUtf8ByteArray("client");
+ long obfuscatedTicketAge = 0L;
+
+ PskIdentity matchIdentity = new PskIdentity(identity, obfuscatedTicketAge);
+
+ for (int i = 0, count = identities.Count; i < count; ++i)
+ {
+ if (matchIdentity.Equals(identities[i]))
+ {
+ TlsSecret key = Crypto.CreateSecret(Strings.ToUtf8ByteArray("TLS_TEST_PSK"));
+ int prfAlgorithm = PrfAlgorithm.tls13_hkdf_sha256;
+
+ return new BasicTlsPskExternal(identity, key, prfAlgorithm);
+ }
+ }
+ return null;
+ }
+
+ public override void NotifyAlertRaised(short alertLevel, short alertDescription, string message,
+ Exception cause)
+ {
+ TextWriter output = (alertLevel == AlertLevel.fatal) ? Console.Error : Console.Out;
+ output.WriteLine("TLS 1.3 PSK server raised alert: " + AlertLevel.GetText(alertLevel)
+ + ", " + AlertDescription.GetText(alertDescription));
+ if (message != null)
+ {
+ output.WriteLine("> " + message);
+ }
+ if (cause != null)
+ {
+ output.WriteLine(cause);
+ }
+ }
+
+ public override void NotifyAlertReceived(short alertLevel, short alertDescription)
+ {
+ TextWriter output = (alertLevel == AlertLevel.fatal) ? Console.Error : Console.Out;
+ output.WriteLine("TLS 1.3 PSK server received alert: " + AlertLevel.GetText(alertLevel)
+ + ", " + AlertDescription.GetText(alertDescription));
+ }
+
+ public override void NotifyHandshakeComplete()
+ {
+ base.NotifyHandshakeComplete();
+
+ ProtocolName protocolName = m_context.SecurityParameters.ApplicationProtocol;
+ if (protocolName != null)
+ {
+ Console.WriteLine("Server ALPN: " + protocolName.GetUtf8Decoding());
+ }
+ }
+ }
+}
diff --git a/crypto/test/src/tls/test/PskTls13ServerTest.cs b/crypto/test/src/tls/test/PskTls13ServerTest.cs
new file mode 100644
index 000000000..4a924b81d
--- /dev/null
+++ b/crypto/test/src/tls/test/PskTls13ServerTest.cs
@@ -0,0 +1,77 @@
+using System;
+using System.IO;
+using System.Net;
+using System.Net.Sockets;
+using System.Threading;
+
+using NUnit.Framework;
+
+using Org.BouncyCastle.Utilities.IO;
+
+namespace Org.BouncyCastle.Tls.Tests
+{
+ [TestFixture]
+ public class PskTls13ServerTest
+ {
+ [Test, Ignore]
+ public void TestConnection()
+ {
+ int port = 5556;
+
+ TcpListener ss = new TcpListener(IPAddress.Any, port);
+ ss.Start();
+ Stream stdout = Console.OpenStandardOutput();
+ try
+ {
+ while (true)
+ {
+ TcpClient s = ss.AcceptTcpClient();
+ Console.WriteLine("--------------------------------------------------------------------------------");
+ Console.WriteLine("Accepted " + s);
+ Server serverRun = new Server(s, stdout);
+ Thread t = new Thread(new ThreadStart(serverRun.Run));
+ t.Start();
+ }
+ }
+ finally
+ {
+ ss.Stop();
+ }
+ }
+
+ internal class Server
+ {
+ private readonly TcpClient s;
+ private readonly Stream stdout;
+
+ internal Server(TcpClient s, Stream stdout)
+ {
+ this.s = s;
+ this.stdout = stdout;
+ }
+
+ public void Run()
+ {
+ try
+ {
+ MockPskTls13Server server = new MockPskTls13Server();
+ TlsServerProtocol serverProtocol = new TlsServerProtocol(s.GetStream());
+ serverProtocol.Accept(server);
+ Stream log = new TeeOutputStream(serverProtocol.Stream, stdout);
+ Streams.PipeAll(serverProtocol.Stream, log);
+ serverProtocol.Close();
+ }
+ finally
+ {
+ try
+ {
+ s.Close();
+ }
+ catch (IOException)
+ {
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/crypto/test/src/tls/test/Tls13PskProtocolTest.cs b/crypto/test/src/tls/test/Tls13PskProtocolTest.cs
new file mode 100644
index 000000000..b66e781a2
--- /dev/null
+++ b/crypto/test/src/tls/test/Tls13PskProtocolTest.cs
@@ -0,0 +1,75 @@
+using System;
+using System.IO;
+using System.Threading;
+
+using NUnit.Framework;
+
+using Org.BouncyCastle.Utilities;
+using Org.BouncyCastle.Utilities.IO;
+
+namespace Org.BouncyCastle.Tls.Tests
+{
+ [TestFixture]
+ public class Tls13PskProtocolTest
+ {
+ [Test]
+ public void TestClientServer()
+ {
+ PipedStream clientPipe = new PipedStream();
+ PipedStream serverPipe = new PipedStream(clientPipe);
+
+ TlsClientProtocol clientProtocol = new TlsClientProtocol(clientPipe);
+ TlsServerProtocol serverProtocol = new TlsServerProtocol(serverPipe);
+
+ MockPskTls13Client client = new MockPskTls13Client();
+ MockPskTls13Server server = new MockPskTls13Server();
+
+ Server serverRun = new Server(serverProtocol, server);
+ Thread serverThread = new Thread(new ThreadStart(serverRun.Run));
+ serverThread.Start();
+
+ clientProtocol.Connect(client);
+
+ byte[] data = new byte[1000];
+ client.Crypto.SecureRandom.NextBytes(data);
+
+ Stream output = clientProtocol.Stream;
+ output.Write(data, 0, data.Length);
+
+ byte[] echo = new byte[data.Length];
+ int count = Streams.ReadFully(clientProtocol.Stream, echo);
+
+ Assert.AreEqual(count, data.Length);
+ Assert.IsTrue(Arrays.AreEqual(data, echo));
+
+ output.Close();
+
+ serverThread.Join();
+ }
+
+ internal class Server
+ {
+ private readonly TlsServerProtocol m_serverProtocol;
+ private readonly TlsServer m_server;
+
+ internal Server(TlsServerProtocol serverProtocol, TlsServer server)
+ {
+ this.m_serverProtocol = serverProtocol;
+ this.m_server = server;
+ }
+
+ public void Run()
+ {
+ try
+ {
+ m_serverProtocol.Accept(m_server);
+ Streams.PipeAll(m_serverProtocol.Stream, m_serverProtocol.Stream);
+ m_serverProtocol.Close();
+ }
+ catch (Exception)
+ {
+ }
+ }
+ }
+ }
+}
|