summary refs log tree commit diff
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2023-07-12 13:45:24 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2023-07-12 13:45:24 +0700
commit8c5383ea76e88a0957f7a903698720a8ef7af0f4 (patch)
tree71506c0fa4ff860b488abfbc989a4914f35077d8
parentReduce DTLS test timeouts (diff)
downloadBouncyCastle.NET-ed25519-8c5383ea76e88a0957f7a903698720a8ef7af0f4.tar.xz
Add DtlsRawKeysProtocolTest
-rw-r--r--crypto/test/src/tls/test/DtlsRawKeysProtocolTest.cs353
1 files changed, 353 insertions, 0 deletions
diff --git a/crypto/test/src/tls/test/DtlsRawKeysProtocolTest.cs b/crypto/test/src/tls/test/DtlsRawKeysProtocolTest.cs
new file mode 100644
index 000000000..36e22b58d
--- /dev/null
+++ b/crypto/test/src/tls/test/DtlsRawKeysProtocolTest.cs
@@ -0,0 +1,353 @@
+using System;
+using System.Text;
+using System.Threading;
+
+using NUnit.Framework;
+
+using Org.BouncyCastle.Crypto.Parameters;
+using Org.BouncyCastle.Security;
+using Org.BouncyCastle.Tls.Crypto;
+using Org.BouncyCastle.Utilities;
+
+namespace Org.BouncyCastle.Tls.Tests
+{
+    [TestFixture]
+    public class DtlsRawKeysProtocolTest
+    {
+        private readonly SecureRandom Random = new SecureRandom();
+
+        [Test]
+        public void TestClientSendsExtensionButServerDoesNotSupportIt()
+        {
+            TestClientSendsExtensionButServerDoesNotSupportIt(ProtocolVersion.DTLSv12);
+        }
+
+        // TODO[dtls13]
+        //[Test]
+        //public void TestClientSendsExtensionButServerDoesNotSupportIt_13()
+        //{
+        //    TestClientSendsExtensionButServerDoesNotSupportIt(ProtocolVersion.DTLSv13);
+        //}
+
+        private void TestClientSendsExtensionButServerDoesNotSupportIt(ProtocolVersion tlsVersion)
+        {
+            MockRawKeysTlsClient client = new MockRawKeysTlsClient(CertificateType.X509, -1,
+                new short[]{ CertificateType.RawPublicKey, CertificateType.X509 }, null, GenerateKeyPair(),
+                tlsVersion);
+            MockRawKeysTlsServer server = new MockRawKeysTlsServer(CertificateType.X509, -1, null, GenerateKeyPair(),
+                tlsVersion);
+            PumpData(client, server);
+        }
+
+        [Test]
+        public void TestExtensionsAreOmittedIfSpecifiedButOnlyContainX509()
+        {
+            TestExtensionsAreOmittedIfSpecifiedButOnlyContainX509(ProtocolVersion.DTLSv12);
+        }
+
+        // TODO[dtls13]
+        //[Test]
+        //public void TestExtensionsAreOmittedIfSpecifiedButOnlyContainX509_13()
+        //{
+        //    TestExtensionsAreOmittedIfSpecifiedButOnlyContainX509(ProtocolVersion.DTLSv13);
+        //}
+
+        private void TestExtensionsAreOmittedIfSpecifiedButOnlyContainX509(ProtocolVersion tlsVersion)
+        {
+            MockRawKeysTlsClient client = new MockRawKeysTlsClient(CertificateType.X509, CertificateType.X509,
+                new short[]{ CertificateType.X509 }, new short[]{ CertificateType.X509 }, GenerateKeyPair(),
+                tlsVersion);
+            MockRawKeysTlsServer server = new MockRawKeysTlsServer(CertificateType.X509, CertificateType.X509,
+                new short[]{ CertificateType.X509 }, GenerateKeyPair(), tlsVersion);
+            PumpData(client, server);
+
+            Assert.IsFalse(server.m_receivedClientExtensions.ContainsKey(ExtensionType.client_certificate_type),
+                "client cert type extension should not be sent");
+            Assert.IsFalse(server.m_receivedClientExtensions.ContainsKey(ExtensionType.server_certificate_type),
+                "server cert type extension should not be sent");
+        }
+
+        [Test]
+        public void TestBothSidesUseRawKey()
+        {
+            TestBothSidesUseRawKey(ProtocolVersion.DTLSv12);
+        }
+
+        // TODO[dtls13]
+        //[Test]
+        //public void TestBothSidesUseRawKey_13()
+        //{
+        //    TestBothSidesUseRawKey(ProtocolVersion.DTLSv13);
+        //}
+
+        private void TestBothSidesUseRawKey(ProtocolVersion tlsVersion)
+        {
+            MockRawKeysTlsClient client = new MockRawKeysTlsClient(CertificateType.RawPublicKey,
+                CertificateType.RawPublicKey, new short[]{ CertificateType.RawPublicKey },
+                new short[]{ CertificateType.RawPublicKey }, GenerateKeyPair(), tlsVersion);
+            MockRawKeysTlsServer server = new MockRawKeysTlsServer(CertificateType.RawPublicKey,
+                CertificateType.RawPublicKey, new short[]{ CertificateType.RawPublicKey }, GenerateKeyPair(),
+                tlsVersion);
+            PumpData(client, server);
+        }
+
+        [Test]
+        public void TestServerUsesRawKeyAndClientIsAnonymous()
+        {
+            TestServerUsesRawKeyAndClientIsAnonymous(ProtocolVersion.DTLSv12);
+        }
+
+        // TODO[dtls13]
+        //[Test]
+        //public void TestServerUsesRawKeyAndClientIsAnonymous_13()
+        //{
+        //    TestServerUsesRawKeyAndClientIsAnonymous(ProtocolVersion.DTLSv13);
+        //}
+
+        private void TestServerUsesRawKeyAndClientIsAnonymous(ProtocolVersion tlsVersion)
+        {
+            MockRawKeysTlsClient client = new MockRawKeysTlsClient(CertificateType.RawPublicKey, -1,
+                new short[]{ CertificateType.RawPublicKey }, null, GenerateKeyPair(), tlsVersion);
+            MockRawKeysTlsServer server = new MockRawKeysTlsServer(CertificateType.RawPublicKey, -1, null,
+                GenerateKeyPair(), tlsVersion);
+            PumpData(client, server);
+        }
+
+        [Test]
+        public void TestServerUsesRawKeyAndClientUsesX509()
+        {
+            TestServerUsesRawKeyAndClientUsesX509(ProtocolVersion.DTLSv12);
+        }
+
+        // TODO[dtls13]
+        //[Test]
+        //public void TestServerUsesRawKeyAndClientUsesX509_13()
+        //{
+        //    TestServerUsesRawKeyAndClientUsesX509(ProtocolVersion.DTLSv13);
+        //}
+
+        private void TestServerUsesRawKeyAndClientUsesX509(ProtocolVersion tlsVersion)
+        {
+            MockRawKeysTlsClient client = new MockRawKeysTlsClient(CertificateType.RawPublicKey,
+                CertificateType.X509, new short[]{ CertificateType.RawPublicKey }, null, GenerateKeyPair(),
+                tlsVersion);
+            MockRawKeysTlsServer server = new MockRawKeysTlsServer(CertificateType.RawPublicKey,
+                CertificateType.X509, null, GenerateKeyPair(), tlsVersion);
+            PumpData(client, server);
+        }
+
+        [Test]
+        public void TestServerUsesX509AndClientUsesRawKey()
+        {
+            TestServerUsesX509AndClientUsesRawKey(ProtocolVersion.DTLSv12);
+        }
+
+        // TODO[dtls13]
+        //[Test]
+        //public void TestServerUsesX509AndClientUsesRawKey_13()
+        //{
+        //    TestServerUsesX509AndClientUsesRawKey(ProtocolVersion.DTLSv13);
+        //}
+
+        private void TestServerUsesX509AndClientUsesRawKey(ProtocolVersion tlsVersion)
+        {
+            MockRawKeysTlsClient client = new MockRawKeysTlsClient(CertificateType.X509, CertificateType.RawPublicKey,
+                null, new short[]{ CertificateType.RawPublicKey }, GenerateKeyPair(), tlsVersion);
+            MockRawKeysTlsServer server = new MockRawKeysTlsServer(CertificateType.X509, CertificateType.RawPublicKey,
+                new short[]{ CertificateType.RawPublicKey }, GenerateKeyPair(), tlsVersion);
+            PumpData(client, server);
+        }
+
+        // NOTE: Test disabled because of problems getting a clean exit of the DTLS server after a fatal alert.
+/*
+        [Test]
+        public void TestClientSendsClientCertExtensionButServerHasNoCommonTypes()
+        {
+            TestClientSendsClientCertExtensionButServerHasNoCommonTypes(ProtocolVersion.DTLSv12);
+        }
+
+        // TODO[dtls13]
+        //[Test]
+        //public void TestClientSendsClientCertExtensionButServerHasNoCommonTypes_13()
+        //{
+        //    TestClientSendsClientCertExtensionButServerHasNoCommonTypes(ProtocolVersion.DTLSv13);
+        //}
+
+        private void TestClientSendsClientCertExtensionButServerHasNoCommonTypes(ProtocolVersion tlsVersion)
+        {
+            try
+            {
+                MockRawKeysTlsClient client = new MockRawKeysTlsClient(CertificateType.X509,
+                    CertificateType.RawPublicKey, null, new short[]{ CertificateType.RawPublicKey }, GenerateKeyPair(),
+                    tlsVersion);
+                MockRawKeysTlsServer server = new MockRawKeysTlsServer(CertificateType.X509, CertificateType.X509,
+                    new short[]{ CertificateType.X509 }, GenerateKeyPair(), tlsVersion);
+                PumpData(client, server);
+                Assert.Fail("Should have caused unsupported_certificate alert");
+            }
+            catch (TlsFatalAlertReceived alert)
+            {
+                Assert.AreEqual(AlertDescription.unsupported_certificate, alert.AlertDescription,
+                    "Should have caused unsupported_certificate alert");
+            }
+        }
+*/
+
+        // NOTE: Test disabled because of problems getting a clean exit of the DTLS server after a fatal alert.
+/*
+        [Test]
+        public void TestClientSendsServerCertExtensionButServerHasNoCommonTypes()
+        {
+            TestClientSendsServerCertExtensionButServerHasNoCommonTypes(ProtocolVersion.DTLSv12);
+        }
+
+        // TODO[dtls13]
+        //[Test]
+        //public void TestClientSendsServerCertExtensionButServerHasNoCommonTypes_13()
+        //{
+        //    TestClientSendsServerCertExtensionButServerHasNoCommonTypes(ProtocolVersion.DTLSv13);
+        //}
+
+        private void TestClientSendsServerCertExtensionButServerHasNoCommonTypes(ProtocolVersion tlsVersion)
+        {
+            try
+            {
+                MockRawKeysTlsClient client = new MockRawKeysTlsClient(CertificateType.RawPublicKey,
+                    CertificateType.RawPublicKey, new short[]{ CertificateType.RawPublicKey }, null, GenerateKeyPair(),
+                    tlsVersion);
+                MockRawKeysTlsServer server = new MockRawKeysTlsServer(CertificateType.X509,
+                    CertificateType.RawPublicKey, new short[]{ CertificateType.RawPublicKey }, GenerateKeyPair(),
+                    tlsVersion);
+                PumpData(client, server);
+                Assert.Fail("Should have caused unsupported_certificate alert");
+            }
+            catch (TlsFatalAlertReceived alert)
+            {
+                Assert.AreEqual(AlertDescription.unsupported_certificate, alert.AlertDescription,
+                    "Should have caused unsupported_certificate alert");
+            }
+        }
+*/
+
+        private Ed25519PrivateKeyParameters GenerateKeyPair()
+        {
+            return new Ed25519PrivateKeyParameters(Random);
+        }
+
+        private void PumpData(TlsClient client, TlsServer server)
+        {
+            DtlsClientProtocol clientProtocol = new DtlsClientProtocol();
+            DtlsServerProtocol serverProtocol = new DtlsServerProtocol();
+
+            MockDatagramAssociation network = new MockDatagramAssociation(1500);
+
+            ServerTask serverTask = new ServerTask(serverProtocol, server, network.Server);
+
+            Thread serverThread = new Thread(new ThreadStart(serverTask.Run));
+            serverThread.Start();
+
+            DatagramTransport clientTransport = network.Client;
+
+            clientTransport = new UnreliableDatagramTransport(clientTransport, Random, 0, 0);
+
+            clientTransport = new LoggingDatagramTransport(clientTransport, Console.Out);
+
+            DtlsTransport dtlsClient = clientProtocol.Connect(client, clientTransport);
+
+            for (int i = 1; i <= 10; ++i)
+            {
+                byte[] data = new byte[i];
+                Arrays.Fill(data, (byte)i);
+                dtlsClient.Send(data, 0, data.Length);
+            }
+
+            byte[] buf = new byte[dtlsClient.GetReceiveLimit()];
+            while (dtlsClient.Receive(buf, 0, buf.Length, 100) >= 0)
+            {
+            }
+
+            dtlsClient.Close();
+
+            serverTask.Shutdown(serverThread);
+        }
+
+        internal class ServerTask
+        {
+            private readonly DtlsServerProtocol m_serverProtocol;
+            private readonly TlsServer m_server;
+            private readonly DatagramTransport m_serverTransport;
+            private volatile bool m_isShutdown = false;
+
+            internal ServerTask(DtlsServerProtocol serverProtocol, TlsServer server, DatagramTransport serverTransport)
+            {
+                this.m_serverProtocol = serverProtocol;
+                this.m_server = server;
+                this.m_serverTransport = serverTransport;
+            }
+
+            public void Run()
+            {
+                try
+                {
+                    TlsCrypto serverCrypto = m_server.Crypto;
+
+                    DtlsRequest request = null;
+
+                    // Use DtlsVerifier to require a HelloVerifyRequest cookie exchange before accepting
+                    {
+                        DtlsVerifier verifier = new DtlsVerifier(serverCrypto);
+
+                        // NOTE: Test value only - would typically be the client IP address
+                        byte[] clientID = Encoding.UTF8.GetBytes("MockRawKeysTlsClient");
+
+                        int receiveLimit = m_serverTransport.GetReceiveLimit();
+                        int dummyOffset = serverCrypto.SecureRandom.Next(16) + 1;
+                        byte[] buf = new byte[dummyOffset + m_serverTransport.GetReceiveLimit()];
+
+                        do
+                        {
+                            if (m_isShutdown)
+                                return;
+
+                            int length = m_serverTransport.Receive(buf, dummyOffset, receiveLimit, 100);
+                            if (length > 0)
+                            {
+                                request = verifier.VerifyRequest(clientID, buf, dummyOffset, length, m_serverTransport);
+                            }
+                        }
+                        while (request == null);
+                    }
+
+                    // NOTE: A real server would handle each DtlsRequest in a new task/thread and continue accepting
+                    {
+                        DtlsTransport dtlsTransport = m_serverProtocol.Accept(m_server, m_serverTransport, request);
+                        byte[] buf = new byte[dtlsTransport.GetReceiveLimit()];
+                        while (!m_isShutdown)
+                        {
+                            int length = dtlsTransport.Receive(buf, 0, buf.Length, 100);
+                            if (length >= 0)
+                            {
+                                dtlsTransport.Send(buf, 0, length);
+                            }
+                        }
+                        dtlsTransport.Close();
+                    }
+                }
+                catch (Exception e)
+                {
+                    Console.Error.WriteLine(e);
+                    Console.Error.Flush();
+                }
+            }
+
+            internal void Shutdown(Thread serverThread)
+            {
+                if (!m_isShutdown)
+                {
+                    this.m_isShutdown = true;
+                    serverThread.Join();
+                }
+            }
+        }
+    }
+}