summary refs log tree commit diff
path: root/crypto/src
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src')
-rw-r--r--crypto/src/crypto/tls/AbstractTlsPeer.cs5
-rw-r--r--crypto/src/crypto/tls/DtlsClientProtocol.cs3
-rw-r--r--crypto/src/crypto/tls/DtlsRecordLayer.cs7
-rw-r--r--crypto/src/crypto/tls/DtlsReliableHandshake.cs13
-rw-r--r--crypto/src/crypto/tls/DtlsServerProtocol.cs4
-rw-r--r--crypto/src/crypto/tls/TlsPeer.cs9
6 files changed, 30 insertions, 11 deletions
diff --git a/crypto/src/crypto/tls/AbstractTlsPeer.cs b/crypto/src/crypto/tls/AbstractTlsPeer.cs
index 2081ce8e5..e7bfc1742 100644
--- a/crypto/src/crypto/tls/AbstractTlsPeer.cs
+++ b/crypto/src/crypto/tls/AbstractTlsPeer.cs
@@ -23,6 +23,11 @@ namespace Org.BouncyCastle.Crypto.Tls
             this.mCloseHandle = closeHandle;
         }
 
+        public virtual int GetHandshakeTimeoutMillis()
+        {
+            return 0;
+        }
+
         public virtual bool RequiresExtendedMasterSecret()
         {
             return false;
diff --git a/crypto/src/crypto/tls/DtlsClientProtocol.cs b/crypto/src/crypto/tls/DtlsClientProtocol.cs
index 4c08bbcfc..fe6381dfa 100644
--- a/crypto/src/crypto/tls/DtlsClientProtocol.cs
+++ b/crypto/src/crypto/tls/DtlsClientProtocol.cs
@@ -82,7 +82,8 @@ namespace Org.BouncyCastle.Crypto.Tls
         internal virtual DtlsTransport ClientHandshake(ClientHandshakeState state, DtlsRecordLayer recordLayer)
         {
             SecurityParameters securityParameters = state.clientContext.SecurityParameters;
-            DtlsReliableHandshake handshake = new DtlsReliableHandshake(state.clientContext, recordLayer);
+            DtlsReliableHandshake handshake = new DtlsReliableHandshake(state.clientContext, recordLayer,
+                state.client.GetHandshakeTimeoutMillis());
 
             byte[] clientHelloBody = GenerateClientHello(state, state.client);
 
diff --git a/crypto/src/crypto/tls/DtlsRecordLayer.cs b/crypto/src/crypto/tls/DtlsRecordLayer.cs
index c1a26b14f..5f3ec9e9c 100644
--- a/crypto/src/crypto/tls/DtlsRecordLayer.cs
+++ b/crypto/src/crypto/tls/DtlsRecordLayer.cs
@@ -152,12 +152,7 @@ namespace Org.BouncyCastle.Crypto.Tls
         {
             long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
 
-            Timeout timeout = null;
-            if (waitMillis > 0)
-            {
-                timeout = new Timeout(waitMillis, currentTimeMillis);
-            }
-
+            Timeout timeout = Timeout.ForWaitMillis(waitMillis, currentTimeMillis); 
             byte[] record = null;
 
             while (waitMillis >= 0)
diff --git a/crypto/src/crypto/tls/DtlsReliableHandshake.cs b/crypto/src/crypto/tls/DtlsReliableHandshake.cs
index 3eeb8a61e..4fc351376 100644
--- a/crypto/src/crypto/tls/DtlsReliableHandshake.cs
+++ b/crypto/src/crypto/tls/DtlsReliableHandshake.cs
@@ -16,6 +16,7 @@ namespace Org.BouncyCastle.Crypto.Tls
         private const int MaxResendMillis = 60000;
 
         private readonly DtlsRecordLayer mRecordLayer;
+        private readonly Timeout mHandshakeTimeout;
 
         private TlsHandshakeHash mHandshakeHash;
 
@@ -28,9 +29,10 @@ namespace Org.BouncyCastle.Crypto.Tls
 
         private int mMessageSeq = 0, mNextReceiveSeq = 0;
 
-        internal DtlsReliableHandshake(TlsContext context, DtlsRecordLayer transport)
+        internal DtlsReliableHandshake(TlsContext context, DtlsRecordLayer transport, int timeoutMillis)
         {
             this.mRecordLayer = transport;
+            this.mHandshakeTimeout = Timeout.ForWaitMillis(timeoutMillis); 
             this.mHandshakeHash = new DeferredHash();
             this.mHandshakeHash.Init(context);
         }
@@ -85,7 +87,6 @@ namespace Org.BouncyCastle.Crypto.Tls
 
         internal Message ReceiveMessage()
         {
-            // TODO Add support for "overall" handshake timeout
             long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
 
             if (mResendTimeout == null)
@@ -107,7 +108,15 @@ namespace Org.BouncyCastle.Crypto.Tls
                 if (pending != null)
                     return pending;
 
+                int handshakeMillis = Timeout.GetWaitMillis(mHandshakeTimeout, currentTimeMillis);
+                if (handshakeMillis < 0)
+                    throw new TlsFatalAlert(AlertDescription.handshake_failure);
+
                 int waitMillis = System.Math.Max(1, Timeout.GetWaitMillis(mResendTimeout, currentTimeMillis));
+                if (handshakeMillis > 0)
+                {
+                    waitMillis = System.Math.Min(waitMillis, handshakeMillis);
+                }
 
                 int receiveLimit = mRecordLayer.GetReceiveLimit();
                 if (buf == null || buf.Length < receiveLimit)
diff --git a/crypto/src/crypto/tls/DtlsServerProtocol.cs b/crypto/src/crypto/tls/DtlsServerProtocol.cs
index 242e1bee5..b4ed75198 100644
--- a/crypto/src/crypto/tls/DtlsServerProtocol.cs
+++ b/crypto/src/crypto/tls/DtlsServerProtocol.cs
@@ -83,8 +83,8 @@ namespace Org.BouncyCastle.Crypto.Tls
         internal virtual DtlsTransport ServerHandshake(ServerHandshakeState state, DtlsRecordLayer recordLayer)
         {
             SecurityParameters securityParameters = state.serverContext.SecurityParameters;
-            DtlsReliableHandshake handshake = new DtlsReliableHandshake(state.serverContext, recordLayer);
-
+            DtlsReliableHandshake handshake = new DtlsReliableHandshake(state.serverContext, recordLayer,
+                state.server.GetHandshakeTimeoutMillis());
             DtlsReliableHandshake.Message clientMessage = handshake.ReceiveMessage();
 
             // NOTE: DTLSRecordLayer requires any DTLS version, we don't otherwise constrain this
diff --git a/crypto/src/crypto/tls/TlsPeer.cs b/crypto/src/crypto/tls/TlsPeer.cs
index a1e99f3fd..817871b14 100644
--- a/crypto/src/crypto/tls/TlsPeer.cs
+++ b/crypto/src/crypto/tls/TlsPeer.cs
@@ -11,6 +11,15 @@ namespace Org.BouncyCastle.Crypto.Tls
         void Cancel();
 
         /// <summary>
+        /// Specify the timeout, in milliseconds, to use for the complete handshake process.
+        /// </summary>
+        /// <remarks>
+        /// Negative values are not allowed. A timeout of zero means an infinite timeout (i.e. the
+        /// handshake will never time out). NOTE: Currently only respected by DTLS protocols.
+        /// </remarks>
+        int GetHandshakeTimeoutMillis();
+
+        /// <summary>
         /// This implementation supports RFC 7627 and will always negotiate the extended_master_secret
         /// extension where possible.
         /// </summary>