diff --git a/crypto/src/crypto/tls/DtlsClientProtocol.cs b/crypto/src/crypto/tls/DtlsClientProtocol.cs
index 411e7cca2..7cb554ae8 100644
--- a/crypto/src/crypto/tls/DtlsClientProtocol.cs
+++ b/crypto/src/crypto/tls/DtlsClientProtocol.cs
@@ -74,13 +74,16 @@ namespace Org.BouncyCastle.Crypto.Tls
DtlsReliableHandshake handshake = new DtlsReliableHandshake(state.clientContext, recordLayer);
byte[] clientHelloBody = GenerateClientHello(state, state.client);
+
+ recordLayer.SetWriteVersion(ProtocolVersion.DTLSv10);
+
handshake.SendMessage(HandshakeType.client_hello, clientHelloBody);
DtlsReliableHandshake.Message serverMessage = handshake.ReceiveMessage();
while (serverMessage.Type == HandshakeType.hello_verify_request)
{
- ProtocolVersion recordLayerVersion = recordLayer.ResetDiscoveredPeerVersion();
+ ProtocolVersion recordLayerVersion = recordLayer.ReadVersion;
ProtocolVersion client_version = state.clientContext.ClientVersion;
/*
@@ -92,6 +95,8 @@ namespace Org.BouncyCastle.Crypto.Tls
if (!recordLayerVersion.IsEqualOrEarlierVersionOf(client_version))
throw new TlsFatalAlert(AlertDescription.illegal_parameter);
+ recordLayer.ReadVersion = null;
+
byte[] cookie = ProcessHelloVerifyRequest(state, serverMessage.Body);
byte[] patched = PatchClientHelloWithCookie(clientHelloBody, cookie);
@@ -103,7 +108,9 @@ namespace Org.BouncyCastle.Crypto.Tls
if (serverMessage.Type == HandshakeType.server_hello)
{
- ReportServerVersion(state, recordLayer.DiscoveredPeerVersion);
+ ProtocolVersion recordLayerVersion = recordLayer.ReadVersion;
+ ReportServerVersion(state, recordLayerVersion);
+ recordLayer.SetWriteVersion(recordLayerVersion);
ProcessServerHello(state, serverMessage.Body);
}
diff --git a/crypto/src/crypto/tls/DtlsRecordLayer.cs b/crypto/src/crypto/tls/DtlsRecordLayer.cs
index 70befd9e4..6796f4cbb 100644
--- a/crypto/src/crypto/tls/DtlsRecordLayer.cs
+++ b/crypto/src/crypto/tls/DtlsRecordLayer.cs
@@ -21,7 +21,7 @@ namespace Org.BouncyCastle.Crypto.Tls
private volatile bool mClosed = false;
private volatile bool mFailed = false;
- private volatile ProtocolVersion mDiscoveredPeerVersion = null;
+ private volatile ProtocolVersion mReadVersion = null, mWriteVersion = null;
private volatile bool mInHandshake;
private volatile int mPlaintextLimit;
private DtlsEpoch mCurrentEpoch, mPendingEpoch;
@@ -52,16 +52,15 @@ namespace Org.BouncyCastle.Crypto.Tls
this.mPlaintextLimit = plaintextLimit;
}
- internal virtual ProtocolVersion DiscoveredPeerVersion
+ internal virtual ProtocolVersion ReadVersion
{
- get { return mDiscoveredPeerVersion; }
+ get { return mReadVersion; }
+ set { this.mReadVersion = value; }
}
- internal virtual ProtocolVersion ResetDiscoveredPeerVersion()
+ internal virtual void SetWriteVersion(ProtocolVersion writeVersion)
{
- ProtocolVersion result = mDiscoveredPeerVersion;
- mDiscoveredPeerVersion = null;
- return result;
+ this.mWriteVersion = writeVersion;
}
internal virtual void InitPendingEpoch(TlsCipher pendingCipher)
@@ -199,7 +198,12 @@ namespace Org.BouncyCastle.Crypto.Tls
}
ProtocolVersion version = TlsUtilities.ReadVersion(record, 1);
- if (mDiscoveredPeerVersion != null && !mDiscoveredPeerVersion.Equals(version))
+ if (!version.IsDtls)
+ {
+ continue;
+ }
+
+ if (mReadVersion != null && !mReadVersion.Equals(version))
{
continue;
}
@@ -215,9 +219,9 @@ namespace Org.BouncyCastle.Crypto.Tls
continue;
}
- if (mDiscoveredPeerVersion == null)
+ if (mReadVersion == null)
{
- mDiscoveredPeerVersion = version;
+ mReadVersion = version;
}
switch (type)
@@ -469,6 +473,10 @@ namespace Org.BouncyCastle.Crypto.Tls
private void SendRecord(byte contentType, byte[] buf, int off, int len)
{
+ // Never send anything until a valid ClientHello has been received
+ if (mWriteVersion == null)
+ return;
+
if (len > this.mPlaintextLimit)
throw new TlsFatalAlert(AlertDescription.internal_error);
@@ -489,7 +497,7 @@ namespace Org.BouncyCastle.Crypto.Tls
byte[] record = new byte[ciphertext.Length + RECORD_HEADER_LENGTH];
TlsUtilities.WriteUint8(contentType, record, 0);
- ProtocolVersion version = mDiscoveredPeerVersion != null ? mDiscoveredPeerVersion : mContext.ClientVersion;
+ ProtocolVersion version = mWriteVersion;
TlsUtilities.WriteVersion(version, record, 1);
TlsUtilities.WriteUint16(recordEpoch, record, 3);
TlsUtilities.WriteUint48(recordSequenceNumber, record, 5);
diff --git a/crypto/src/crypto/tls/DtlsServerProtocol.cs b/crypto/src/crypto/tls/DtlsServerProtocol.cs
index c556d6320..171984b6f 100644
--- a/crypto/src/crypto/tls/DtlsServerProtocol.cs
+++ b/crypto/src/crypto/tls/DtlsServerProtocol.cs
@@ -76,12 +76,8 @@ namespace Org.BouncyCastle.Crypto.Tls
DtlsReliableHandshake.Message clientMessage = handshake.ReceiveMessage();
- {
- // NOTE: After receiving a record from the client, we discover the record layer version
- ProtocolVersion client_version = recordLayer.DiscoveredPeerVersion;
- // TODO Read RFCs for guidance on the expected record layer version number
- state.serverContext.SetClientVersion(client_version);
- }
+ // NOTE: DTLSRecordLayer requires any DTLS version, we don't otherwise constrain this
+ //ProtocolVersion recordLayerVersion = recordLayer.ReadVersion;
if (clientMessage.Type == HandshakeType.client_hello)
{
@@ -97,6 +93,10 @@ namespace Org.BouncyCastle.Crypto.Tls
ApplyMaxFragmentLengthExtension(recordLayer, securityParameters.maxFragmentLength);
+ ProtocolVersion recordLayerVersion = state.serverContext.ServerVersion;
+ recordLayer.ReadVersion = recordLayerVersion;
+ recordLayer.SetWriteVersion(recordLayerVersion);
+
handshake.SendMessage(HandshakeType.server_hello, serverHelloBody);
}
diff --git a/crypto/test/src/crypto/tls/test/DtlsTestSuite.cs b/crypto/test/src/crypto/tls/test/DtlsTestSuite.cs
index e9e4411af..a1ba62dde 100644
--- a/crypto/test/src/crypto/tls/test/DtlsTestSuite.cs
+++ b/crypto/test/src/crypto/tls/test/DtlsTestSuite.cs
@@ -203,11 +203,6 @@ namespace Org.BouncyCastle.Crypto.Tls.Tests
private static void AddTestCase(IList testSuite, TlsTestConfig config, String name)
{
- //testSuite.Add(new TestCaseData(config).SetName(name));
- }
-
- private static void AddTestCaseDebug(IList testSuite, TlsTestConfig config, String name)
- {
testSuite.Add(new TestCaseData(config).SetName(name));
}
@@ -215,11 +210,7 @@ namespace Org.BouncyCastle.Crypto.Tls.Tests
{
TlsTestConfig c = new TlsTestConfig();
c.clientMinimumVersion = ProtocolVersion.DTLSv10;
- /*
- * TODO We'd like to just set the offer version to DTLSv12, but there is a known issue with
- * overly-restrictive version checks b/w BC DTLS 1.2 client, BC DTLS 1.0 server
- */
- c.clientOfferVersion = version;
+ c.clientOfferVersion = ProtocolVersion.DTLSv12;
c.serverMaximumVersion = version;
c.serverMinimumVersion = ProtocolVersion.DTLSv10;
return c;
diff --git a/crypto/test/src/crypto/tls/test/MockDtlsClient.cs b/crypto/test/src/crypto/tls/test/MockDtlsClient.cs
index e3c604db7..25057b8ce 100644
--- a/crypto/test/src/crypto/tls/test/MockDtlsClient.cs
+++ b/crypto/test/src/crypto/tls/test/MockDtlsClient.cs
@@ -73,8 +73,13 @@ namespace Org.BouncyCastle.Crypto.Tls.Tests
IDictionary clientExtensions = TlsExtensionsUtilities.EnsureExtensionsInitialised(base.GetClientExtensions());
TlsExtensionsUtilities.AddEncryptThenMacExtension(clientExtensions);
TlsExtensionsUtilities.AddExtendedMasterSecretExtension(clientExtensions);
- TlsExtensionsUtilities.AddMaxFragmentLengthExtension(clientExtensions, MaxFragmentLength.pow2_9);
- TlsExtensionsUtilities.AddTruncatedHMacExtension(clientExtensions);
+ {
+ /*
+ * NOTE: If you are copying test code, do not blindly set these extensions in your own client.
+ */
+ TlsExtensionsUtilities.AddMaxFragmentLengthExtension(clientExtensions, MaxFragmentLength.pow2_9);
+ TlsExtensionsUtilities.AddTruncatedHMacExtension(clientExtensions);
+ }
return clientExtensions;
}
diff --git a/crypto/test/src/crypto/tls/test/MockTlsClient.cs b/crypto/test/src/crypto/tls/test/MockTlsClient.cs
index 7c1198632..35c5b3599 100644
--- a/crypto/test/src/crypto/tls/test/MockTlsClient.cs
+++ b/crypto/test/src/crypto/tls/test/MockTlsClient.cs
@@ -63,8 +63,13 @@ namespace Org.BouncyCastle.Crypto.Tls.Tests
IDictionary clientExtensions = TlsExtensionsUtilities.EnsureExtensionsInitialised(base.GetClientExtensions());
TlsExtensionsUtilities.AddEncryptThenMacExtension(clientExtensions);
TlsExtensionsUtilities.AddExtendedMasterSecretExtension(clientExtensions);
- TlsExtensionsUtilities.AddMaxFragmentLengthExtension(clientExtensions, MaxFragmentLength.pow2_9);
- TlsExtensionsUtilities.AddTruncatedHMacExtension(clientExtensions);
+ {
+ /*
+ * NOTE: If you are copying test code, do not blindly set these extensions in your own client.
+ */
+ TlsExtensionsUtilities.AddMaxFragmentLengthExtension(clientExtensions, MaxFragmentLength.pow2_9);
+ TlsExtensionsUtilities.AddTruncatedHMacExtension(clientExtensions);
+ }
return clientExtensions;
}
|