diff --git a/crypto/crypto.csproj b/crypto/crypto.csproj
index e06b37f9f..5781f7711 100644
--- a/crypto/crypto.csproj
+++ b/crypto/crypto.csproj
@@ -14599,6 +14599,11 @@
BuildAction = "Compile"
/>
<File
+ RelPath = "test\src\crypto\tls\test\MockPskTls13Server.cs"
+ SubType = "Code"
+ BuildAction = "Compile"
+ />
+ <File
RelPath = "test\src\crypto\tls\test\MockPskTlsClient.cs"
SubType = "Code"
BuildAction = "Compile"
@@ -14644,6 +14649,11 @@
BuildAction = "Compile"
/>
<File
+ RelPath = "test\src\crypto\tls\test\PskTls13ServerTest.cs"
+ SubType = "Code"
+ BuildAction = "Compile"
+ />
+ <File
RelPath = "test\src\crypto\tls\test\PskTlsClientTest.cs"
SubType = "Code"
BuildAction = "Compile"
@@ -14654,6 +14664,11 @@
BuildAction = "Compile"
/>
<File
+ RelPath = "test\src\crypto\tls\test\Tls13PskProtocolTest.cs"
+ SubType = "Code"
+ BuildAction = "Compile"
+ />
+ <File
RelPath = "test\src\crypto\tls\test\TlsClientTest.cs"
SubType = "Code"
BuildAction = "Compile"
diff --git a/crypto/src/tls/TlsServerProtocol.cs b/crypto/src/tls/TlsServerProtocol.cs
index e14fb7d70..40218a2fb 100644
--- a/crypto/src/tls/TlsServerProtocol.cs
+++ b/crypto/src/tls/TlsServerProtocol.cs
@@ -121,7 +121,8 @@ namespace Org.BouncyCastle.Tls
}
/// <exception cref="IOException"/>
- protected virtual ServerHello Generate13ServerHello(ClientHello clientHello, bool afterHelloRetryRequest)
+ protected virtual ServerHello Generate13ServerHello(ClientHello clientHello,
+ HandshakeMessageInput clientHelloMessage, bool afterHelloRetryRequest)
{
SecurityParameters securityParameters = m_tlsServerContext.SecurityParameters;
@@ -136,6 +137,10 @@ namespace Org.BouncyCastle.Tls
ProtocolVersion serverVersion = securityParameters.NegotiatedVersion;
TlsCrypto crypto = m_tlsServerContext.Crypto;
+ // NOTE: Will only select for psk_dhe_ke
+ OfferedPsks.SelectedConfig selectedPsk = TlsUtilities.SelectPreSharedKey(m_tlsServerContext, m_tlsServer,
+ clientHelloExtensions, clientHelloMessage, m_handshakeHash, afterHelloRetryRequest);
+
IList clientShares = TlsExtensionsUtilities.GetKeyShareClientHello(clientHelloExtensions);
KeyShareEntry clientShare = null;
@@ -144,6 +149,23 @@ namespace Org.BouncyCastle.Tls
if (m_retryGroup < 0)
throw new TlsFatalAlert(AlertDescription.internal_error);
+ if (null == selectedPsk)
+ {
+ /*
+ * RFC 8446 4.2.3. If a server is authenticating via a certificate and the client has
+ * not sent a "signature_algorithms" extension, then the server MUST abort the handshake
+ * with a "missing_extension" alert.
+ */
+ if (null == securityParameters.ClientSigAlgs)
+ throw new TlsFatalAlert(AlertDescription.missing_extension);
+ }
+ else
+ {
+ // TODO[tls13] Maybe filter the offered PSKs by PRF algorithm before server selection instead
+ if (selectedPsk.m_psk.PrfAlgorithm != securityParameters.PrfAlgorithm)
+ throw new TlsFatalAlert(AlertDescription.illegal_parameter);
+ }
+
/*
* TODO[tls13] Confirm fields in the ClientHello haven't changed
*
@@ -182,8 +204,7 @@ namespace Org.BouncyCastle.Tls
* not sent a "signature_algorithms" extension, then the server MUST abort the handshake
* with a "missing_extension" alert.
*/
- // TODO[tls13] Revisit this check if we add support for PSK-only key exchange.
- if (null == securityParameters.ClientSigAlgs)
+ if (null == selectedPsk && null == securityParameters.ClientSigAlgs)
throw new TlsFatalAlert(AlertDescription.missing_extension);
m_tlsServer.ProcessClientExtensions(clientHelloExtensions);
@@ -218,6 +239,7 @@ namespace Org.BouncyCastle.Tls
}
{
+ // TODO[tls13] Constrain selection when PSK selected
int cipherSuite = m_tlsServer.GetSelectedCipherSuite();
if (!TlsUtilities.IsValidCipherSuiteSelection(m_offeredCipherSuites, cipherSuite) ||
@@ -309,11 +331,17 @@ namespace Org.BouncyCastle.Tls
this.m_expectSessionTicket = false;
- // TODO[tls13-psk] Use PSK early secret if negotiated
TlsSecret pskEarlySecret = null;
+ if (null != selectedPsk)
+ {
+ pskEarlySecret = selectedPsk.m_earlySecret;
- TlsSecret sharedSecret = null;
+ this.m_selectedPsk13 = true;
+ TlsExtensionsUtilities.AddPreSharedKeyServerHello(serverHelloExtensions, selectedPsk.m_index);
+ }
+
+ TlsSecret sharedSecret;
{
int namedGroup = clientShare.NamedGroup;
@@ -353,7 +381,8 @@ namespace Org.BouncyCastle.Tls
}
/// <exception cref="IOException"/>
- protected virtual ServerHello GenerateServerHello(ClientHello clientHello)
+ protected virtual ServerHello GenerateServerHello(ClientHello clientHello,
+ HandshakeMessageInput clientHelloMessage)
{
ProtocolVersion clientLegacyVersion = clientHello.Version;
if (!clientLegacyVersion.IsTls)
@@ -426,7 +455,7 @@ namespace Org.BouncyCastle.Tls
m_recordStream.SetWriteVersion(ProtocolVersion.TLSv12);
- return Generate13ServerHello(clientHello, false);
+ return Generate13ServerHello(clientHello, clientHelloMessage, false);
}
m_recordStream.SetWriteVersion(serverVersion);
@@ -747,10 +776,9 @@ namespace Org.BouncyCastle.Tls
case CS_SERVER_HELLO_RETRY_REQUEST:
{
ClientHello clientHelloRetry = ReceiveClientHelloMessage(buf);
- buf.UpdateHash(m_handshakeHash);
this.m_connectionState = CS_CLIENT_HELLO_RETRY;
- ServerHello serverHello = Generate13ServerHello(clientHelloRetry, true);
+ ServerHello serverHello = Generate13ServerHello(clientHelloRetry, buf, true);
SendServerHelloMessage(serverHello);
this.m_connectionState = CS_SERVER_HELLO;
@@ -866,10 +894,9 @@ namespace Org.BouncyCastle.Tls
case CS_START:
{
ClientHello clientHello = ReceiveClientHelloMessage(buf);
- buf.UpdateHash(m_handshakeHash);
this.m_connectionState = CS_CLIENT_HELLO;
- ServerHello serverHello = GenerateServerHello(clientHello);
+ ServerHello serverHello = GenerateServerHello(clientHello, buf);
m_handshakeHash.NotifyPrfDetermined();
if (TlsUtilities.IsTlsV13(securityParameters.NegotiatedVersion))
@@ -898,6 +925,9 @@ namespace Org.BouncyCastle.Tls
break;
}
+ // For TLS 1.3+, this was already done by GenerateServerHello
+ buf.UpdateHash(m_handshakeHash);
+
SendServerHelloMessage(serverHello);
this.m_connectionState = CS_SERVER_HELLO;
@@ -1042,9 +1072,6 @@ namespace Org.BouncyCastle.Tls
m_tlsServer.ProcessClientSupplementalData(null);
}
- if (m_certificateRequest == null)
- throw new TlsFatalAlert(AlertDescription.unexpected_message);
-
ReceiveCertificateMessage(buf);
this.m_connectionState = CS_CLIENT_CERTIFICATE;
break;
@@ -1232,6 +1259,9 @@ namespace Org.BouncyCastle.Tls
{
// TODO[tls13] This currently just duplicates 'receiveCertificateMessage'
+ if (null == m_certificateRequest)
+ throw new TlsFatalAlert(AlertDescription.unexpected_message);
+
Certificate.ParseOptions options = new Certificate.ParseOptions()
.SetMaxChainLength(m_tlsServer.GetMaxCertificateChainLength());
@@ -1267,6 +1297,9 @@ namespace Org.BouncyCastle.Tls
/// <exception cref="IOException"/>
protected virtual void ReceiveCertificateMessage(MemoryStream buf)
{
+ if (null == m_certificateRequest)
+ throw new TlsFatalAlert(AlertDescription.unexpected_message);
+
Certificate.ParseOptions options = new Certificate.ParseOptions()
.SetMaxChainLength(m_tlsServer.GetMaxCertificateChainLength());
@@ -1353,52 +1386,57 @@ namespace Org.BouncyCastle.Tls
Send13EncryptedExtensionsMessage(m_serverExtensions);
this.m_connectionState = CS_SERVER_ENCRYPTED_EXTENSIONS;
- // CertificateRequest
+ if (m_selectedPsk13)
+ {
+ /*
+ * For PSK-only key exchange, there's no CertificateRequest, Certificate, CertificateVerify.
+ */
+ }
+ else
{
- this.m_certificateRequest = m_tlsServer.GetCertificateRequest();
- if (null != m_certificateRequest)
+ // CertificateRequest
{
- if (!m_certificateRequest.HasCertificateRequestContext(TlsUtilities.EmptyBytes))
- throw new TlsFatalAlert(AlertDescription.internal_error);
+ this.m_certificateRequest = m_tlsServer.GetCertificateRequest();
+ if (null != m_certificateRequest)
+ {
+ if (!m_certificateRequest.HasCertificateRequestContext(TlsUtilities.EmptyBytes))
+ throw new TlsFatalAlert(AlertDescription.internal_error);
- TlsUtilities.EstablishServerSigAlgs(securityParameters, m_certificateRequest);
+ TlsUtilities.EstablishServerSigAlgs(securityParameters, m_certificateRequest);
- SendCertificateRequestMessage(m_certificateRequest);
- this.m_connectionState = CS_SERVER_CERTIFICATE_REQUEST;
+ SendCertificateRequestMessage(m_certificateRequest);
+ this.m_connectionState = CS_SERVER_CERTIFICATE_REQUEST;
+ }
}
- }
-
- /*
- * TODO[tls13] For PSK-only key exchange, there's no Certificate message.
- */
- TlsCredentialedSigner serverCredentials = TlsUtilities.Establish13ServerCredentials(m_tlsServer);
- if (null == serverCredentials)
- throw new TlsFatalAlert(AlertDescription.internal_error);
+ TlsCredentialedSigner serverCredentials = TlsUtilities.Establish13ServerCredentials(m_tlsServer);
+ if (null == serverCredentials)
+ throw new TlsFatalAlert(AlertDescription.internal_error);
- // Certificate
- {
- /*
- * TODO[tls13] Note that we are expecting the TlsServer implementation to take care of
- * e.g. adding optional "status_request" extension to each CertificateEntry.
- */
- /*
- * No CertificateStatus message is sent; TLS 1.3 uses per-CertificateEntry
- * "status_request" extension instead.
- */
+ // Certificate
+ {
+ /*
+ * TODO[tls13] Note that we are expecting the TlsServer implementation to take care of
+ * e.g. adding optional "status_request" extension to each CertificateEntry.
+ */
+ /*
+ * No CertificateStatus message is sent; TLS 1.3 uses per-CertificateEntry
+ * "status_request" extension instead.
+ */
- Certificate serverCertificate = serverCredentials.Certificate;
- Send13CertificateMessage(serverCertificate);
- securityParameters.m_tlsServerEndPoint = null;
- this.m_connectionState = CS_SERVER_CERTIFICATE;
- }
+ Certificate serverCertificate = serverCredentials.Certificate;
+ Send13CertificateMessage(serverCertificate);
+ securityParameters.m_tlsServerEndPoint = null;
+ this.m_connectionState = CS_SERVER_CERTIFICATE;
+ }
- // CertificateVerify
- {
- DigitallySigned certificateVerify = TlsUtilities.Generate13CertificateVerify(m_tlsServerContext,
- serverCredentials, m_handshakeHash);
- Send13CertificateVerifyMessage(certificateVerify);
- this.m_connectionState = CS_CLIENT_CERTIFICATE_VERIFY;
+ // CertificateVerify
+ {
+ DigitallySigned certificateVerify = TlsUtilities.Generate13CertificateVerify(m_tlsServerContext,
+ serverCredentials, m_handshakeHash);
+ Send13CertificateVerifyMessage(certificateVerify);
+ this.m_connectionState = CS_CLIENT_CERTIFICATE_VERIFY;
+ }
}
// Finished
diff --git a/crypto/test/UnitTests.csproj b/crypto/test/UnitTests.csproj
index 1650a05fa..1945f1367 100644
--- a/crypto/test/UnitTests.csproj
+++ b/crypto/test/UnitTests.csproj
@@ -491,6 +491,7 @@
<Compile Include="src\tls\test\MockPskDtlsClient.cs" />
<Compile Include="src\tls\test\MockPskDtlsServer.cs" />
<Compile Include="src\tls\test\MockPskTls13Client.cs" />
+ <Compile Include="src\tls\test\MockPskTls13Server.cs" />
<Compile Include="src\tls\test\MockPskTlsClient.cs" />
<Compile Include="src\tls\test\MockPskTlsServer.cs" />
<Compile Include="src\tls\test\MockSrpTlsClient.cs" />
@@ -501,8 +502,10 @@
<Compile Include="src\tls\test\PipedStream.cs" />
<Compile Include="src\tls\test\PrfTest.cs" />
<Compile Include="src\tls\test\PskTls13ClientTest.cs" />
+ <Compile Include="src\tls\test\PskTls13ServerTest.cs" />
<Compile Include="src\tls\test\PskTlsClientTest.cs" />
<Compile Include="src\tls\test\PskTlsServerTest.cs" />
+ <Compile Include="src\tls\test\Tls13PskProtocolTest.cs" />
<Compile Include="src\tls\test\TlsClientTest.cs" />
<Compile Include="src\tls\test\TlsProtocolNonBlockingTest.cs" />
<Compile Include="src\tls\test\TlsProtocolTest.cs" />
diff --git a/crypto/test/src/tls/test/MockPskTls13Server.cs b/crypto/test/src/tls/test/MockPskTls13Server.cs
new file mode 100644
index 000000000..d1ea69b95
--- /dev/null
+++ b/crypto/test/src/tls/test/MockPskTls13Server.cs
@@ -0,0 +1,108 @@
+using System;
+using System.Collections;
+using System.IO;
+
+using Org.BouncyCastle.Tls.Crypto;
+using Org.BouncyCastle.Tls.Crypto.Impl.BC;
+using Org.BouncyCastle.Security;
+using Org.BouncyCastle.Utilities;
+
+namespace Org.BouncyCastle.Tls.Tests
+{
+ internal class MockPskTls13Server
+ : AbstractTlsServer
+ {
+ internal MockPskTls13Server()
+ : base(new BcTlsCrypto(new SecureRandom()))
+ {
+ }
+
+ public override TlsCredentials GetCredentials()
+ {
+ return null;
+ }
+
+ protected override IList GetProtocolNames()
+ {
+ IList protocolNames = new ArrayList();
+ protocolNames.Add(ProtocolName.Http_2_Tls);
+ protocolNames.Add(ProtocolName.Http_1_1);
+ return protocolNames;
+ }
+
+ protected override int[] GetSupportedCipherSuites()
+ {
+ return TlsUtilities.GetSupportedCipherSuites(Crypto,
+ new int[] { CipherSuite.TLS_AES_128_CCM_8_SHA256, CipherSuite.TLS_AES_128_CCM_SHA256,
+ CipherSuite.TLS_AES_128_GCM_SHA256, CipherSuite.TLS_CHACHA20_POLY1305_SHA256 });
+ }
+
+ protected override ProtocolVersion[] GetSupportedVersions()
+ {
+ return ProtocolVersion.TLSv13.Only();
+ }
+
+ public override ProtocolVersion GetServerVersion()
+ {
+ ProtocolVersion serverVersion = base.GetServerVersion();
+
+ Console.WriteLine("TLS 1.3 PSK server negotiated " + serverVersion);
+
+ return serverVersion;
+ }
+
+ public override TlsPskExternal GetExternalPsk(IList identities)
+ {
+ byte[] identity = Strings.ToUtf8ByteArray("client");
+ long obfuscatedTicketAge = 0L;
+
+ PskIdentity matchIdentity = new PskIdentity(identity, obfuscatedTicketAge);
+
+ for (int i = 0, count = identities.Count; i < count; ++i)
+ {
+ if (matchIdentity.Equals(identities[i]))
+ {
+ TlsSecret key = Crypto.CreateSecret(Strings.ToUtf8ByteArray("TLS_TEST_PSK"));
+ int prfAlgorithm = PrfAlgorithm.tls13_hkdf_sha256;
+
+ return new BasicTlsPskExternal(identity, key, prfAlgorithm);
+ }
+ }
+ return null;
+ }
+
+ public override void NotifyAlertRaised(short alertLevel, short alertDescription, string message,
+ Exception cause)
+ {
+ TextWriter output = (alertLevel == AlertLevel.fatal) ? Console.Error : Console.Out;
+ output.WriteLine("TLS 1.3 PSK server raised alert: " + AlertLevel.GetText(alertLevel)
+ + ", " + AlertDescription.GetText(alertDescription));
+ if (message != null)
+ {
+ output.WriteLine("> " + message);
+ }
+ if (cause != null)
+ {
+ output.WriteLine(cause);
+ }
+ }
+
+ public override void NotifyAlertReceived(short alertLevel, short alertDescription)
+ {
+ TextWriter output = (alertLevel == AlertLevel.fatal) ? Console.Error : Console.Out;
+ output.WriteLine("TLS 1.3 PSK server received alert: " + AlertLevel.GetText(alertLevel)
+ + ", " + AlertDescription.GetText(alertDescription));
+ }
+
+ public override void NotifyHandshakeComplete()
+ {
+ base.NotifyHandshakeComplete();
+
+ ProtocolName protocolName = m_context.SecurityParameters.ApplicationProtocol;
+ if (protocolName != null)
+ {
+ Console.WriteLine("Server ALPN: " + protocolName.GetUtf8Decoding());
+ }
+ }
+ }
+}
diff --git a/crypto/test/src/tls/test/PskTls13ServerTest.cs b/crypto/test/src/tls/test/PskTls13ServerTest.cs
new file mode 100644
index 000000000..4a924b81d
--- /dev/null
+++ b/crypto/test/src/tls/test/PskTls13ServerTest.cs
@@ -0,0 +1,77 @@
+using System;
+using System.IO;
+using System.Net;
+using System.Net.Sockets;
+using System.Threading;
+
+using NUnit.Framework;
+
+using Org.BouncyCastle.Utilities.IO;
+
+namespace Org.BouncyCastle.Tls.Tests
+{
+ [TestFixture]
+ public class PskTls13ServerTest
+ {
+ [Test, Ignore]
+ public void TestConnection()
+ {
+ int port = 5556;
+
+ TcpListener ss = new TcpListener(IPAddress.Any, port);
+ ss.Start();
+ Stream stdout = Console.OpenStandardOutput();
+ try
+ {
+ while (true)
+ {
+ TcpClient s = ss.AcceptTcpClient();
+ Console.WriteLine("--------------------------------------------------------------------------------");
+ Console.WriteLine("Accepted " + s);
+ Server serverRun = new Server(s, stdout);
+ Thread t = new Thread(new ThreadStart(serverRun.Run));
+ t.Start();
+ }
+ }
+ finally
+ {
+ ss.Stop();
+ }
+ }
+
+ internal class Server
+ {
+ private readonly TcpClient s;
+ private readonly Stream stdout;
+
+ internal Server(TcpClient s, Stream stdout)
+ {
+ this.s = s;
+ this.stdout = stdout;
+ }
+
+ public void Run()
+ {
+ try
+ {
+ MockPskTls13Server server = new MockPskTls13Server();
+ TlsServerProtocol serverProtocol = new TlsServerProtocol(s.GetStream());
+ serverProtocol.Accept(server);
+ Stream log = new TeeOutputStream(serverProtocol.Stream, stdout);
+ Streams.PipeAll(serverProtocol.Stream, log);
+ serverProtocol.Close();
+ }
+ finally
+ {
+ try
+ {
+ s.Close();
+ }
+ catch (IOException)
+ {
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/crypto/test/src/tls/test/Tls13PskProtocolTest.cs b/crypto/test/src/tls/test/Tls13PskProtocolTest.cs
new file mode 100644
index 000000000..b66e781a2
--- /dev/null
+++ b/crypto/test/src/tls/test/Tls13PskProtocolTest.cs
@@ -0,0 +1,75 @@
+using System;
+using System.IO;
+using System.Threading;
+
+using NUnit.Framework;
+
+using Org.BouncyCastle.Utilities;
+using Org.BouncyCastle.Utilities.IO;
+
+namespace Org.BouncyCastle.Tls.Tests
+{
+ [TestFixture]
+ public class Tls13PskProtocolTest
+ {
+ [Test]
+ public void TestClientServer()
+ {
+ PipedStream clientPipe = new PipedStream();
+ PipedStream serverPipe = new PipedStream(clientPipe);
+
+ TlsClientProtocol clientProtocol = new TlsClientProtocol(clientPipe);
+ TlsServerProtocol serverProtocol = new TlsServerProtocol(serverPipe);
+
+ MockPskTls13Client client = new MockPskTls13Client();
+ MockPskTls13Server server = new MockPskTls13Server();
+
+ Server serverRun = new Server(serverProtocol, server);
+ Thread serverThread = new Thread(new ThreadStart(serverRun.Run));
+ serverThread.Start();
+
+ clientProtocol.Connect(client);
+
+ byte[] data = new byte[1000];
+ client.Crypto.SecureRandom.NextBytes(data);
+
+ Stream output = clientProtocol.Stream;
+ output.Write(data, 0, data.Length);
+
+ byte[] echo = new byte[data.Length];
+ int count = Streams.ReadFully(clientProtocol.Stream, echo);
+
+ Assert.AreEqual(count, data.Length);
+ Assert.IsTrue(Arrays.AreEqual(data, echo));
+
+ output.Close();
+
+ serverThread.Join();
+ }
+
+ internal class Server
+ {
+ private readonly TlsServerProtocol m_serverProtocol;
+ private readonly TlsServer m_server;
+
+ internal Server(TlsServerProtocol serverProtocol, TlsServer server)
+ {
+ this.m_serverProtocol = serverProtocol;
+ this.m_server = server;
+ }
+
+ public void Run()
+ {
+ try
+ {
+ m_serverProtocol.Accept(m_server);
+ Streams.PipeAll(m_serverProtocol.Stream, m_serverProtocol.Stream);
+ m_serverProtocol.Close();
+ }
+ catch (Exception)
+ {
+ }
+ }
+ }
+ }
+}
|