diff options
author | Peter Dettman <peter.dettman@bouncycastle.org> | 2021-10-17 22:20:10 +0700 |
---|---|---|
committer | Peter Dettman <peter.dettman@bouncycastle.org> | 2021-10-17 22:20:10 +0700 |
commit | 649a6d3835e1e9f77fda0f13ef67d73ed8e3c044 (patch) | |
tree | 5243740db06347ae5d5eb1426c5ecbd9ba048900 | |
parent | Server-side PSK selection (diff) | |
download | BouncyCastle.NET-ed25519-649a6d3835e1e9f77fda0f13ef67d73ed8e3c044.tar.xz |
Experimental server-side TLS 1.3 PSK
-rw-r--r-- | crypto/crypto.csproj | 15 | ||||
-rw-r--r-- | crypto/src/tls/TlsServerProtocol.cs | 140 | ||||
-rw-r--r-- | crypto/test/UnitTests.csproj | 3 | ||||
-rw-r--r-- | crypto/test/src/tls/test/MockPskTls13Server.cs | 108 | ||||
-rw-r--r-- | crypto/test/src/tls/test/PskTls13ServerTest.cs | 77 | ||||
-rw-r--r-- | crypto/test/src/tls/test/Tls13PskProtocolTest.cs | 75 |
6 files changed, 367 insertions, 51 deletions
diff --git a/crypto/crypto.csproj b/crypto/crypto.csproj index e06b37f9f..5781f7711 100644 --- a/crypto/crypto.csproj +++ b/crypto/crypto.csproj @@ -14599,6 +14599,11 @@ BuildAction = "Compile" /> <File + RelPath = "test\src\crypto\tls\test\MockPskTls13Server.cs" + SubType = "Code" + BuildAction = "Compile" + /> + <File RelPath = "test\src\crypto\tls\test\MockPskTlsClient.cs" SubType = "Code" BuildAction = "Compile" @@ -14644,6 +14649,11 @@ BuildAction = "Compile" /> <File + RelPath = "test\src\crypto\tls\test\PskTls13ServerTest.cs" + SubType = "Code" + BuildAction = "Compile" + /> + <File RelPath = "test\src\crypto\tls\test\PskTlsClientTest.cs" SubType = "Code" BuildAction = "Compile" @@ -14654,6 +14664,11 @@ BuildAction = "Compile" /> <File + RelPath = "test\src\crypto\tls\test\Tls13PskProtocolTest.cs" + SubType = "Code" + BuildAction = "Compile" + /> + <File RelPath = "test\src\crypto\tls\test\TlsClientTest.cs" SubType = "Code" BuildAction = "Compile" diff --git a/crypto/src/tls/TlsServerProtocol.cs b/crypto/src/tls/TlsServerProtocol.cs index e14fb7d70..40218a2fb 100644 --- a/crypto/src/tls/TlsServerProtocol.cs +++ b/crypto/src/tls/TlsServerProtocol.cs @@ -121,7 +121,8 @@ namespace Org.BouncyCastle.Tls } /// <exception cref="IOException"/> - protected virtual ServerHello Generate13ServerHello(ClientHello clientHello, bool afterHelloRetryRequest) + protected virtual ServerHello Generate13ServerHello(ClientHello clientHello, + HandshakeMessageInput clientHelloMessage, bool afterHelloRetryRequest) { SecurityParameters securityParameters = m_tlsServerContext.SecurityParameters; @@ -136,6 +137,10 @@ namespace Org.BouncyCastle.Tls ProtocolVersion serverVersion = securityParameters.NegotiatedVersion; TlsCrypto crypto = m_tlsServerContext.Crypto; + // NOTE: Will only select for psk_dhe_ke + OfferedPsks.SelectedConfig selectedPsk = TlsUtilities.SelectPreSharedKey(m_tlsServerContext, m_tlsServer, + clientHelloExtensions, clientHelloMessage, m_handshakeHash, afterHelloRetryRequest); + IList clientShares = TlsExtensionsUtilities.GetKeyShareClientHello(clientHelloExtensions); KeyShareEntry clientShare = null; @@ -144,6 +149,23 @@ namespace Org.BouncyCastle.Tls if (m_retryGroup < 0) throw new TlsFatalAlert(AlertDescription.internal_error); + if (null == selectedPsk) + { + /* + * RFC 8446 4.2.3. If a server is authenticating via a certificate and the client has + * not sent a "signature_algorithms" extension, then the server MUST abort the handshake + * with a "missing_extension" alert. + */ + if (null == securityParameters.ClientSigAlgs) + throw new TlsFatalAlert(AlertDescription.missing_extension); + } + else + { + // TODO[tls13] Maybe filter the offered PSKs by PRF algorithm before server selection instead + if (selectedPsk.m_psk.PrfAlgorithm != securityParameters.PrfAlgorithm) + throw new TlsFatalAlert(AlertDescription.illegal_parameter); + } + /* * TODO[tls13] Confirm fields in the ClientHello haven't changed * @@ -182,8 +204,7 @@ namespace Org.BouncyCastle.Tls * not sent a "signature_algorithms" extension, then the server MUST abort the handshake * with a "missing_extension" alert. */ - // TODO[tls13] Revisit this check if we add support for PSK-only key exchange. - if (null == securityParameters.ClientSigAlgs) + if (null == selectedPsk && null == securityParameters.ClientSigAlgs) throw new TlsFatalAlert(AlertDescription.missing_extension); m_tlsServer.ProcessClientExtensions(clientHelloExtensions); @@ -218,6 +239,7 @@ namespace Org.BouncyCastle.Tls } { + // TODO[tls13] Constrain selection when PSK selected int cipherSuite = m_tlsServer.GetSelectedCipherSuite(); if (!TlsUtilities.IsValidCipherSuiteSelection(m_offeredCipherSuites, cipherSuite) || @@ -309,11 +331,17 @@ namespace Org.BouncyCastle.Tls this.m_expectSessionTicket = false; - // TODO[tls13-psk] Use PSK early secret if negotiated TlsSecret pskEarlySecret = null; + if (null != selectedPsk) + { + pskEarlySecret = selectedPsk.m_earlySecret; - TlsSecret sharedSecret = null; + this.m_selectedPsk13 = true; + TlsExtensionsUtilities.AddPreSharedKeyServerHello(serverHelloExtensions, selectedPsk.m_index); + } + + TlsSecret sharedSecret; { int namedGroup = clientShare.NamedGroup; @@ -353,7 +381,8 @@ namespace Org.BouncyCastle.Tls } /// <exception cref="IOException"/> - protected virtual ServerHello GenerateServerHello(ClientHello clientHello) + protected virtual ServerHello GenerateServerHello(ClientHello clientHello, + HandshakeMessageInput clientHelloMessage) { ProtocolVersion clientLegacyVersion = clientHello.Version; if (!clientLegacyVersion.IsTls) @@ -426,7 +455,7 @@ namespace Org.BouncyCastle.Tls m_recordStream.SetWriteVersion(ProtocolVersion.TLSv12); - return Generate13ServerHello(clientHello, false); + return Generate13ServerHello(clientHello, clientHelloMessage, false); } m_recordStream.SetWriteVersion(serverVersion); @@ -747,10 +776,9 @@ namespace Org.BouncyCastle.Tls case CS_SERVER_HELLO_RETRY_REQUEST: { ClientHello clientHelloRetry = ReceiveClientHelloMessage(buf); - buf.UpdateHash(m_handshakeHash); this.m_connectionState = CS_CLIENT_HELLO_RETRY; - ServerHello serverHello = Generate13ServerHello(clientHelloRetry, true); + ServerHello serverHello = Generate13ServerHello(clientHelloRetry, buf, true); SendServerHelloMessage(serverHello); this.m_connectionState = CS_SERVER_HELLO; @@ -866,10 +894,9 @@ namespace Org.BouncyCastle.Tls case CS_START: { ClientHello clientHello = ReceiveClientHelloMessage(buf); - buf.UpdateHash(m_handshakeHash); this.m_connectionState = CS_CLIENT_HELLO; - ServerHello serverHello = GenerateServerHello(clientHello); + ServerHello serverHello = GenerateServerHello(clientHello, buf); m_handshakeHash.NotifyPrfDetermined(); if (TlsUtilities.IsTlsV13(securityParameters.NegotiatedVersion)) @@ -898,6 +925,9 @@ namespace Org.BouncyCastle.Tls break; } + // For TLS 1.3+, this was already done by GenerateServerHello + buf.UpdateHash(m_handshakeHash); + SendServerHelloMessage(serverHello); this.m_connectionState = CS_SERVER_HELLO; @@ -1042,9 +1072,6 @@ namespace Org.BouncyCastle.Tls m_tlsServer.ProcessClientSupplementalData(null); } - if (m_certificateRequest == null) - throw new TlsFatalAlert(AlertDescription.unexpected_message); - ReceiveCertificateMessage(buf); this.m_connectionState = CS_CLIENT_CERTIFICATE; break; @@ -1232,6 +1259,9 @@ namespace Org.BouncyCastle.Tls { // TODO[tls13] This currently just duplicates 'receiveCertificateMessage' + if (null == m_certificateRequest) + throw new TlsFatalAlert(AlertDescription.unexpected_message); + Certificate.ParseOptions options = new Certificate.ParseOptions() .SetMaxChainLength(m_tlsServer.GetMaxCertificateChainLength()); @@ -1267,6 +1297,9 @@ namespace Org.BouncyCastle.Tls /// <exception cref="IOException"/> protected virtual void ReceiveCertificateMessage(MemoryStream buf) { + if (null == m_certificateRequest) + throw new TlsFatalAlert(AlertDescription.unexpected_message); + Certificate.ParseOptions options = new Certificate.ParseOptions() .SetMaxChainLength(m_tlsServer.GetMaxCertificateChainLength()); @@ -1353,52 +1386,57 @@ namespace Org.BouncyCastle.Tls Send13EncryptedExtensionsMessage(m_serverExtensions); this.m_connectionState = CS_SERVER_ENCRYPTED_EXTENSIONS; - // CertificateRequest + if (m_selectedPsk13) + { + /* + * For PSK-only key exchange, there's no CertificateRequest, Certificate, CertificateVerify. + */ + } + else { - this.m_certificateRequest = m_tlsServer.GetCertificateRequest(); - if (null != m_certificateRequest) + // CertificateRequest { - if (!m_certificateRequest.HasCertificateRequestContext(TlsUtilities.EmptyBytes)) - throw new TlsFatalAlert(AlertDescription.internal_error); + this.m_certificateRequest = m_tlsServer.GetCertificateRequest(); + if (null != m_certificateRequest) + { + if (!m_certificateRequest.HasCertificateRequestContext(TlsUtilities.EmptyBytes)) + throw new TlsFatalAlert(AlertDescription.internal_error); - TlsUtilities.EstablishServerSigAlgs(securityParameters, m_certificateRequest); + TlsUtilities.EstablishServerSigAlgs(securityParameters, m_certificateRequest); - SendCertificateRequestMessage(m_certificateRequest); - this.m_connectionState = CS_SERVER_CERTIFICATE_REQUEST; + SendCertificateRequestMessage(m_certificateRequest); + this.m_connectionState = CS_SERVER_CERTIFICATE_REQUEST; + } } - } - - /* - * TODO[tls13] For PSK-only key exchange, there's no Certificate message. - */ - TlsCredentialedSigner serverCredentials = TlsUtilities.Establish13ServerCredentials(m_tlsServer); - if (null == serverCredentials) - throw new TlsFatalAlert(AlertDescription.internal_error); + TlsCredentialedSigner serverCredentials = TlsUtilities.Establish13ServerCredentials(m_tlsServer); + if (null == serverCredentials) + throw new TlsFatalAlert(AlertDescription.internal_error); - // Certificate - { - /* - * TODO[tls13] Note that we are expecting the TlsServer implementation to take care of - * e.g. adding optional "status_request" extension to each CertificateEntry. - */ - /* - * No CertificateStatus message is sent; TLS 1.3 uses per-CertificateEntry - * "status_request" extension instead. - */ + // Certificate + { + /* + * TODO[tls13] Note that we are expecting the TlsServer implementation to take care of + * e.g. adding optional "status_request" extension to each CertificateEntry. + */ + /* + * No CertificateStatus message is sent; TLS 1.3 uses per-CertificateEntry + * "status_request" extension instead. + */ - Certificate serverCertificate = serverCredentials.Certificate; - Send13CertificateMessage(serverCertificate); - securityParameters.m_tlsServerEndPoint = null; - this.m_connectionState = CS_SERVER_CERTIFICATE; - } + Certificate serverCertificate = serverCredentials.Certificate; + Send13CertificateMessage(serverCertificate); + securityParameters.m_tlsServerEndPoint = null; + this.m_connectionState = CS_SERVER_CERTIFICATE; + } - // CertificateVerify - { - DigitallySigned certificateVerify = TlsUtilities.Generate13CertificateVerify(m_tlsServerContext, - serverCredentials, m_handshakeHash); - Send13CertificateVerifyMessage(certificateVerify); - this.m_connectionState = CS_CLIENT_CERTIFICATE_VERIFY; + // CertificateVerify + { + DigitallySigned certificateVerify = TlsUtilities.Generate13CertificateVerify(m_tlsServerContext, + serverCredentials, m_handshakeHash); + Send13CertificateVerifyMessage(certificateVerify); + this.m_connectionState = CS_CLIENT_CERTIFICATE_VERIFY; + } } // Finished diff --git a/crypto/test/UnitTests.csproj b/crypto/test/UnitTests.csproj index 1650a05fa..1945f1367 100644 --- a/crypto/test/UnitTests.csproj +++ b/crypto/test/UnitTests.csproj @@ -491,6 +491,7 @@ <Compile Include="src\tls\test\MockPskDtlsClient.cs" /> <Compile Include="src\tls\test\MockPskDtlsServer.cs" /> <Compile Include="src\tls\test\MockPskTls13Client.cs" /> + <Compile Include="src\tls\test\MockPskTls13Server.cs" /> <Compile Include="src\tls\test\MockPskTlsClient.cs" /> <Compile Include="src\tls\test\MockPskTlsServer.cs" /> <Compile Include="src\tls\test\MockSrpTlsClient.cs" /> @@ -501,8 +502,10 @@ <Compile Include="src\tls\test\PipedStream.cs" /> <Compile Include="src\tls\test\PrfTest.cs" /> <Compile Include="src\tls\test\PskTls13ClientTest.cs" /> + <Compile Include="src\tls\test\PskTls13ServerTest.cs" /> <Compile Include="src\tls\test\PskTlsClientTest.cs" /> <Compile Include="src\tls\test\PskTlsServerTest.cs" /> + <Compile Include="src\tls\test\Tls13PskProtocolTest.cs" /> <Compile Include="src\tls\test\TlsClientTest.cs" /> <Compile Include="src\tls\test\TlsProtocolNonBlockingTest.cs" /> <Compile Include="src\tls\test\TlsProtocolTest.cs" /> diff --git a/crypto/test/src/tls/test/MockPskTls13Server.cs b/crypto/test/src/tls/test/MockPskTls13Server.cs new file mode 100644 index 000000000..d1ea69b95 --- /dev/null +++ b/crypto/test/src/tls/test/MockPskTls13Server.cs @@ -0,0 +1,108 @@ +using System; +using System.Collections; +using System.IO; + +using Org.BouncyCastle.Tls.Crypto; +using Org.BouncyCastle.Tls.Crypto.Impl.BC; +using Org.BouncyCastle.Security; +using Org.BouncyCastle.Utilities; + +namespace Org.BouncyCastle.Tls.Tests +{ + internal class MockPskTls13Server + : AbstractTlsServer + { + internal MockPskTls13Server() + : base(new BcTlsCrypto(new SecureRandom())) + { + } + + public override TlsCredentials GetCredentials() + { + return null; + } + + protected override IList GetProtocolNames() + { + IList protocolNames = new ArrayList(); + protocolNames.Add(ProtocolName.Http_2_Tls); + protocolNames.Add(ProtocolName.Http_1_1); + return protocolNames; + } + + protected override int[] GetSupportedCipherSuites() + { + return TlsUtilities.GetSupportedCipherSuites(Crypto, + new int[] { CipherSuite.TLS_AES_128_CCM_8_SHA256, CipherSuite.TLS_AES_128_CCM_SHA256, + CipherSuite.TLS_AES_128_GCM_SHA256, CipherSuite.TLS_CHACHA20_POLY1305_SHA256 }); + } + + protected override ProtocolVersion[] GetSupportedVersions() + { + return ProtocolVersion.TLSv13.Only(); + } + + public override ProtocolVersion GetServerVersion() + { + ProtocolVersion serverVersion = base.GetServerVersion(); + + Console.WriteLine("TLS 1.3 PSK server negotiated " + serverVersion); + + return serverVersion; + } + + public override TlsPskExternal GetExternalPsk(IList identities) + { + byte[] identity = Strings.ToUtf8ByteArray("client"); + long obfuscatedTicketAge = 0L; + + PskIdentity matchIdentity = new PskIdentity(identity, obfuscatedTicketAge); + + for (int i = 0, count = identities.Count; i < count; ++i) + { + if (matchIdentity.Equals(identities[i])) + { + TlsSecret key = Crypto.CreateSecret(Strings.ToUtf8ByteArray("TLS_TEST_PSK")); + int prfAlgorithm = PrfAlgorithm.tls13_hkdf_sha256; + + return new BasicTlsPskExternal(identity, key, prfAlgorithm); + } + } + return null; + } + + public override void NotifyAlertRaised(short alertLevel, short alertDescription, string message, + Exception cause) + { + TextWriter output = (alertLevel == AlertLevel.fatal) ? Console.Error : Console.Out; + output.WriteLine("TLS 1.3 PSK server raised alert: " + AlertLevel.GetText(alertLevel) + + ", " + AlertDescription.GetText(alertDescription)); + if (message != null) + { + output.WriteLine("> " + message); + } + if (cause != null) + { + output.WriteLine(cause); + } + } + + public override void NotifyAlertReceived(short alertLevel, short alertDescription) + { + TextWriter output = (alertLevel == AlertLevel.fatal) ? Console.Error : Console.Out; + output.WriteLine("TLS 1.3 PSK server received alert: " + AlertLevel.GetText(alertLevel) + + ", " + AlertDescription.GetText(alertDescription)); + } + + public override void NotifyHandshakeComplete() + { + base.NotifyHandshakeComplete(); + + ProtocolName protocolName = m_context.SecurityParameters.ApplicationProtocol; + if (protocolName != null) + { + Console.WriteLine("Server ALPN: " + protocolName.GetUtf8Decoding()); + } + } + } +} diff --git a/crypto/test/src/tls/test/PskTls13ServerTest.cs b/crypto/test/src/tls/test/PskTls13ServerTest.cs new file mode 100644 index 000000000..4a924b81d --- /dev/null +++ b/crypto/test/src/tls/test/PskTls13ServerTest.cs @@ -0,0 +1,77 @@ +using System; +using System.IO; +using System.Net; +using System.Net.Sockets; +using System.Threading; + +using NUnit.Framework; + +using Org.BouncyCastle.Utilities.IO; + +namespace Org.BouncyCastle.Tls.Tests +{ + [TestFixture] + public class PskTls13ServerTest + { + [Test, Ignore] + public void TestConnection() + { + int port = 5556; + + TcpListener ss = new TcpListener(IPAddress.Any, port); + ss.Start(); + Stream stdout = Console.OpenStandardOutput(); + try + { + while (true) + { + TcpClient s = ss.AcceptTcpClient(); + Console.WriteLine("--------------------------------------------------------------------------------"); + Console.WriteLine("Accepted " + s); + Server serverRun = new Server(s, stdout); + Thread t = new Thread(new ThreadStart(serverRun.Run)); + t.Start(); + } + } + finally + { + ss.Stop(); + } + } + + internal class Server + { + private readonly TcpClient s; + private readonly Stream stdout; + + internal Server(TcpClient s, Stream stdout) + { + this.s = s; + this.stdout = stdout; + } + + public void Run() + { + try + { + MockPskTls13Server server = new MockPskTls13Server(); + TlsServerProtocol serverProtocol = new TlsServerProtocol(s.GetStream()); + serverProtocol.Accept(server); + Stream log = new TeeOutputStream(serverProtocol.Stream, stdout); + Streams.PipeAll(serverProtocol.Stream, log); + serverProtocol.Close(); + } + finally + { + try + { + s.Close(); + } + catch (IOException) + { + } + } + } + } + } +} diff --git a/crypto/test/src/tls/test/Tls13PskProtocolTest.cs b/crypto/test/src/tls/test/Tls13PskProtocolTest.cs new file mode 100644 index 000000000..b66e781a2 --- /dev/null +++ b/crypto/test/src/tls/test/Tls13PskProtocolTest.cs @@ -0,0 +1,75 @@ +using System; +using System.IO; +using System.Threading; + +using NUnit.Framework; + +using Org.BouncyCastle.Utilities; +using Org.BouncyCastle.Utilities.IO; + +namespace Org.BouncyCastle.Tls.Tests +{ + [TestFixture] + public class Tls13PskProtocolTest + { + [Test] + public void TestClientServer() + { + PipedStream clientPipe = new PipedStream(); + PipedStream serverPipe = new PipedStream(clientPipe); + + TlsClientProtocol clientProtocol = new TlsClientProtocol(clientPipe); + TlsServerProtocol serverProtocol = new TlsServerProtocol(serverPipe); + + MockPskTls13Client client = new MockPskTls13Client(); + MockPskTls13Server server = new MockPskTls13Server(); + + Server serverRun = new Server(serverProtocol, server); + Thread serverThread = new Thread(new ThreadStart(serverRun.Run)); + serverThread.Start(); + + clientProtocol.Connect(client); + + byte[] data = new byte[1000]; + client.Crypto.SecureRandom.NextBytes(data); + + Stream output = clientProtocol.Stream; + output.Write(data, 0, data.Length); + + byte[] echo = new byte[data.Length]; + int count = Streams.ReadFully(clientProtocol.Stream, echo); + + Assert.AreEqual(count, data.Length); + Assert.IsTrue(Arrays.AreEqual(data, echo)); + + output.Close(); + + serverThread.Join(); + } + + internal class Server + { + private readonly TlsServerProtocol m_serverProtocol; + private readonly TlsServer m_server; + + internal Server(TlsServerProtocol serverProtocol, TlsServer server) + { + this.m_serverProtocol = serverProtocol; + this.m_server = server; + } + + public void Run() + { + try + { + m_serverProtocol.Accept(m_server); + Streams.PipeAll(m_serverProtocol.Stream, m_serverProtocol.Stream); + m_serverProtocol.Close(); + } + catch (Exception) + { + } + } + } + } +} |