summary refs log tree commit diff
path: root/crypto/test/src/tls/test/DtlsTestCase.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/test/src/tls/test/DtlsTestCase.cs')
-rw-r--r--crypto/test/src/tls/test/DtlsTestCase.cs164
1 files changed, 164 insertions, 0 deletions
diff --git a/crypto/test/src/tls/test/DtlsTestCase.cs b/crypto/test/src/tls/test/DtlsTestCase.cs
new file mode 100644
index 000000000..d93f17c27
--- /dev/null
+++ b/crypto/test/src/tls/test/DtlsTestCase.cs
@@ -0,0 +1,164 @@
+using System;
+using System.Threading;
+
+using NUnit.Framework;
+
+using Org.BouncyCastle.Utilities;
+
+namespace Org.BouncyCastle.Tls.Tests
+{
+    [TestFixture]
+    public class DtlsTestCase
+    {
+        private static void CheckDtlsVersions(ProtocolVersion[] versions)
+        {
+            if (versions != null)
+            {
+                for (int i = 0; i < versions.Length; ++i)
+                {
+                    if (!versions[i].IsDtls)
+                        throw new InvalidOperationException("Non-DTLS version");
+                }
+            }
+        }
+
+        [Test, TestCaseSource(typeof(DtlsTestSuite), "Suite")]
+        public void RunTest(TlsTestConfig config)
+        {
+            CheckDtlsVersions(config.clientSupportedVersions);
+            CheckDtlsVersions(config.serverSupportedVersions);
+
+            DtlsTestClientProtocol clientProtocol = new DtlsTestClientProtocol(config);
+            DtlsTestServerProtocol serverProtocol = new DtlsTestServerProtocol(config);
+
+            MockDatagramAssociation network = new MockDatagramAssociation(1500);
+
+            TlsTestClientImpl clientImpl = new TlsTestClientImpl(config);
+            TlsTestServerImpl serverImpl = new TlsTestServerImpl(config);
+
+            Server server = new Server(this, serverProtocol, network.Server, serverImpl);
+
+            Thread serverThread = new Thread(new ThreadStart(server.Run));
+            serverThread.Start();
+
+            Exception caught = null;
+            try
+            {
+                DatagramTransport clientTransport = network.Client;
+
+                if (TlsTestConfig.Debug)
+                {
+                    clientTransport = new LoggingDatagramTransport(clientTransport, Console.Out);
+                }
+
+                DtlsTransport dtlsClient = clientProtocol.Connect(clientImpl, clientTransport);
+
+                for (int i = 1; i <= 10; ++i)
+                {
+                    byte[] data = new byte[i];
+                    Arrays.Fill(data, (byte)i);
+                    dtlsClient.Send(data, 0, data.Length);
+                }
+
+                byte[] buf = new byte[dtlsClient.GetReceiveLimit()];
+                while (dtlsClient.Receive(buf, 0, buf.Length, 100) >= 0)
+                {
+                }
+
+                dtlsClient.Close();
+            }
+            catch (Exception e)
+            {
+                caught = e;
+                LogException(caught);
+            }
+
+            server.Shutdown(serverThread);
+
+            // TODO Add checks that the various streams were closed
+
+            Assert.AreEqual(config.expectFatalAlertConnectionEnd, clientImpl.FirstFatalAlertConnectionEnd,
+                "Client fatal alert connection end");
+            Assert.AreEqual(config.expectFatalAlertConnectionEnd, serverImpl.FirstFatalAlertConnectionEnd,
+                "Server fatal alert connection end");
+
+            Assert.AreEqual(config.expectFatalAlertDescription, clientImpl.FirstFatalAlertDescription,
+                "Client fatal alert description");
+            Assert.AreEqual(config.expectFatalAlertDescription, serverImpl.FirstFatalAlertDescription,
+                "Server fatal alert description");
+
+            if (config.expectFatalAlertConnectionEnd == -1)
+            {
+                Assert.IsNull(caught, "Unexpected client exception");
+                Assert.IsNull(server.Caught, "Unexpected server exception");
+            }
+        }
+
+        protected void LogException(Exception e)
+        {
+            if (TlsTestConfig.Debug)
+            {
+                Console.Error.WriteLine(e);
+                Console.Error.Flush();
+            }
+        }
+
+        internal class Server
+        {
+            private readonly DtlsTestCase m_outer;
+            private readonly DtlsTestServerProtocol m_serverProtocol;
+            private readonly DatagramTransport m_serverTransport;
+            private readonly TlsTestServerImpl m_serverImpl;
+
+            private volatile bool m_isShutdown = false;
+            private Exception m_caught = null;
+
+            internal Server(DtlsTestCase outer, DtlsTestServerProtocol serverProtocol,
+                DatagramTransport serverTransport, TlsTestServerImpl serverImpl)
+            {
+                this.m_outer = outer;
+                this.m_serverProtocol = serverProtocol;
+                this.m_serverTransport = serverTransport;
+                this.m_serverImpl = serverImpl;
+            }
+
+            public void Run()
+            {
+                try
+                {
+                    DtlsTransport dtlsServer = m_serverProtocol.Accept(m_serverImpl, m_serverTransport);
+                    byte[] buf = new byte[dtlsServer.GetReceiveLimit()];
+                    while (!m_isShutdown)
+                    {
+                        int length = dtlsServer.Receive(buf, 0, buf.Length, 100);
+                        if (length >= 0)
+                        {
+                            dtlsServer.Send(buf, 0, length);
+                        }
+                    }
+                    dtlsServer.Close();
+                }
+                catch (Exception e)
+                {
+                    this.m_caught = e;
+                    m_outer.LogException(m_caught);
+                }
+            }
+
+            internal void Shutdown(Thread serverThread)
+            {
+                if (!m_isShutdown)
+                {
+                    this.m_isShutdown = true;
+                    //serverThread.Interrupt();
+                    serverThread.Join();
+                }
+            }
+
+            internal Exception Caught
+            {
+                get { return m_caught; }
+            }
+        }
+    }
+}