summary refs log tree commit diff
path: root/crypto/src
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src')
-rw-r--r--crypto/src/crypto/tls/DtlsClientProtocol.cs11
-rw-r--r--crypto/src/crypto/tls/DtlsRecordLayer.cs30
-rw-r--r--crypto/src/crypto/tls/DtlsServerProtocol.cs12
3 files changed, 34 insertions, 19 deletions
diff --git a/crypto/src/crypto/tls/DtlsClientProtocol.cs b/crypto/src/crypto/tls/DtlsClientProtocol.cs
index 411e7cca2..7cb554ae8 100644
--- a/crypto/src/crypto/tls/DtlsClientProtocol.cs
+++ b/crypto/src/crypto/tls/DtlsClientProtocol.cs
@@ -74,13 +74,16 @@ namespace Org.BouncyCastle.Crypto.Tls
             DtlsReliableHandshake handshake = new DtlsReliableHandshake(state.clientContext, recordLayer);
 
             byte[] clientHelloBody = GenerateClientHello(state, state.client);
+
+            recordLayer.SetWriteVersion(ProtocolVersion.DTLSv10);
+
             handshake.SendMessage(HandshakeType.client_hello, clientHelloBody);
 
             DtlsReliableHandshake.Message serverMessage = handshake.ReceiveMessage();
 
             while (serverMessage.Type == HandshakeType.hello_verify_request)
             {
-                ProtocolVersion recordLayerVersion = recordLayer.ResetDiscoveredPeerVersion();
+                ProtocolVersion recordLayerVersion = recordLayer.ReadVersion;
                 ProtocolVersion client_version = state.clientContext.ClientVersion;
 
                 /*
@@ -92,6 +95,8 @@ namespace Org.BouncyCastle.Crypto.Tls
                 if (!recordLayerVersion.IsEqualOrEarlierVersionOf(client_version))
                     throw new TlsFatalAlert(AlertDescription.illegal_parameter);
 
+                recordLayer.ReadVersion = null;
+
                 byte[] cookie = ProcessHelloVerifyRequest(state, serverMessage.Body);
                 byte[] patched = PatchClientHelloWithCookie(clientHelloBody, cookie);
 
@@ -103,7 +108,9 @@ namespace Org.BouncyCastle.Crypto.Tls
 
             if (serverMessage.Type == HandshakeType.server_hello)
             {
-                ReportServerVersion(state, recordLayer.DiscoveredPeerVersion);
+                ProtocolVersion recordLayerVersion = recordLayer.ReadVersion;
+                ReportServerVersion(state, recordLayerVersion);
+                recordLayer.SetWriteVersion(recordLayerVersion);
 
                 ProcessServerHello(state, serverMessage.Body);
             }
diff --git a/crypto/src/crypto/tls/DtlsRecordLayer.cs b/crypto/src/crypto/tls/DtlsRecordLayer.cs
index 70befd9e4..6796f4cbb 100644
--- a/crypto/src/crypto/tls/DtlsRecordLayer.cs
+++ b/crypto/src/crypto/tls/DtlsRecordLayer.cs
@@ -21,7 +21,7 @@ namespace Org.BouncyCastle.Crypto.Tls
 
         private volatile bool mClosed = false;
         private volatile bool mFailed = false;
-        private volatile ProtocolVersion mDiscoveredPeerVersion = null;
+        private volatile ProtocolVersion mReadVersion = null, mWriteVersion = null;
         private volatile bool mInHandshake;
         private volatile int mPlaintextLimit;
         private DtlsEpoch mCurrentEpoch, mPendingEpoch;
@@ -52,16 +52,15 @@ namespace Org.BouncyCastle.Crypto.Tls
             this.mPlaintextLimit = plaintextLimit;
         }
 
-        internal virtual ProtocolVersion DiscoveredPeerVersion
+        internal virtual ProtocolVersion ReadVersion
         {
-            get { return mDiscoveredPeerVersion; }
+            get { return mReadVersion; }
+            set { this.mReadVersion = value; }
         }
 
-        internal virtual ProtocolVersion ResetDiscoveredPeerVersion()
+        internal virtual void SetWriteVersion(ProtocolVersion writeVersion)
         {
-            ProtocolVersion result = mDiscoveredPeerVersion;
-            mDiscoveredPeerVersion = null;
-            return result;
+            this.mWriteVersion = writeVersion;
         }
 
         internal virtual void InitPendingEpoch(TlsCipher pendingCipher)
@@ -199,7 +198,12 @@ namespace Org.BouncyCastle.Crypto.Tls
                     }
 
                     ProtocolVersion version = TlsUtilities.ReadVersion(record, 1);
-                    if (mDiscoveredPeerVersion != null && !mDiscoveredPeerVersion.Equals(version))
+                    if (!version.IsDtls)
+                    {
+                        continue;
+                    }
+
+                    if (mReadVersion != null && !mReadVersion.Equals(version))
                     {
                         continue;
                     }
@@ -215,9 +219,9 @@ namespace Org.BouncyCastle.Crypto.Tls
                         continue;
                     }
 
-                    if (mDiscoveredPeerVersion == null)
+                    if (mReadVersion == null)
                     {
-                        mDiscoveredPeerVersion = version;
+                        mReadVersion = version;
                     }
 
                     switch (type)
@@ -469,6 +473,10 @@ namespace Org.BouncyCastle.Crypto.Tls
 
         private void SendRecord(byte contentType, byte[] buf, int off, int len)
         {
+            // Never send anything until a valid ClientHello has been received
+            if (mWriteVersion == null)
+                return;
+
             if (len > this.mPlaintextLimit)
                 throw new TlsFatalAlert(AlertDescription.internal_error);
 
@@ -489,7 +497,7 @@ namespace Org.BouncyCastle.Crypto.Tls
 
             byte[] record = new byte[ciphertext.Length + RECORD_HEADER_LENGTH];
             TlsUtilities.WriteUint8(contentType, record, 0);
-            ProtocolVersion version = mDiscoveredPeerVersion != null ? mDiscoveredPeerVersion : mContext.ClientVersion;
+            ProtocolVersion version = mWriteVersion;
             TlsUtilities.WriteVersion(version, record, 1);
             TlsUtilities.WriteUint16(recordEpoch, record, 3);
             TlsUtilities.WriteUint48(recordSequenceNumber, record, 5);
diff --git a/crypto/src/crypto/tls/DtlsServerProtocol.cs b/crypto/src/crypto/tls/DtlsServerProtocol.cs
index c556d6320..171984b6f 100644
--- a/crypto/src/crypto/tls/DtlsServerProtocol.cs
+++ b/crypto/src/crypto/tls/DtlsServerProtocol.cs
@@ -76,12 +76,8 @@ namespace Org.BouncyCastle.Crypto.Tls
 
             DtlsReliableHandshake.Message clientMessage = handshake.ReceiveMessage();
 
-            {
-                // NOTE: After receiving a record from the client, we discover the record layer version
-                ProtocolVersion client_version = recordLayer.DiscoveredPeerVersion;
-                // TODO Read RFCs for guidance on the expected record layer version number
-                state.serverContext.SetClientVersion(client_version);
-            }
+            // NOTE: DTLSRecordLayer requires any DTLS version, we don't otherwise constrain this
+            //ProtocolVersion recordLayerVersion = recordLayer.ReadVersion;
 
             if (clientMessage.Type == HandshakeType.client_hello)
             {
@@ -97,6 +93,10 @@ namespace Org.BouncyCastle.Crypto.Tls
 
                 ApplyMaxFragmentLengthExtension(recordLayer, securityParameters.maxFragmentLength);
 
+                ProtocolVersion recordLayerVersion = state.serverContext.ServerVersion;
+                recordLayer.ReadVersion = recordLayerVersion;
+                recordLayer.SetWriteVersion(recordLayerVersion);
+
                 handshake.SendMessage(HandshakeType.server_hello, serverHelloBody);
             }