diff options
Diffstat (limited to 'crypto/src')
-rw-r--r-- | crypto/src/crypto/tls/DtlsClientProtocol.cs | 11 | ||||
-rw-r--r-- | crypto/src/crypto/tls/DtlsRecordLayer.cs | 30 | ||||
-rw-r--r-- | crypto/src/crypto/tls/DtlsServerProtocol.cs | 12 |
3 files changed, 34 insertions, 19 deletions
diff --git a/crypto/src/crypto/tls/DtlsClientProtocol.cs b/crypto/src/crypto/tls/DtlsClientProtocol.cs index 411e7cca2..7cb554ae8 100644 --- a/crypto/src/crypto/tls/DtlsClientProtocol.cs +++ b/crypto/src/crypto/tls/DtlsClientProtocol.cs @@ -74,13 +74,16 @@ namespace Org.BouncyCastle.Crypto.Tls DtlsReliableHandshake handshake = new DtlsReliableHandshake(state.clientContext, recordLayer); byte[] clientHelloBody = GenerateClientHello(state, state.client); + + recordLayer.SetWriteVersion(ProtocolVersion.DTLSv10); + handshake.SendMessage(HandshakeType.client_hello, clientHelloBody); DtlsReliableHandshake.Message serverMessage = handshake.ReceiveMessage(); while (serverMessage.Type == HandshakeType.hello_verify_request) { - ProtocolVersion recordLayerVersion = recordLayer.ResetDiscoveredPeerVersion(); + ProtocolVersion recordLayerVersion = recordLayer.ReadVersion; ProtocolVersion client_version = state.clientContext.ClientVersion; /* @@ -92,6 +95,8 @@ namespace Org.BouncyCastle.Crypto.Tls if (!recordLayerVersion.IsEqualOrEarlierVersionOf(client_version)) throw new TlsFatalAlert(AlertDescription.illegal_parameter); + recordLayer.ReadVersion = null; + byte[] cookie = ProcessHelloVerifyRequest(state, serverMessage.Body); byte[] patched = PatchClientHelloWithCookie(clientHelloBody, cookie); @@ -103,7 +108,9 @@ namespace Org.BouncyCastle.Crypto.Tls if (serverMessage.Type == HandshakeType.server_hello) { - ReportServerVersion(state, recordLayer.DiscoveredPeerVersion); + ProtocolVersion recordLayerVersion = recordLayer.ReadVersion; + ReportServerVersion(state, recordLayerVersion); + recordLayer.SetWriteVersion(recordLayerVersion); ProcessServerHello(state, serverMessage.Body); } diff --git a/crypto/src/crypto/tls/DtlsRecordLayer.cs b/crypto/src/crypto/tls/DtlsRecordLayer.cs index 70befd9e4..6796f4cbb 100644 --- a/crypto/src/crypto/tls/DtlsRecordLayer.cs +++ b/crypto/src/crypto/tls/DtlsRecordLayer.cs @@ -21,7 +21,7 @@ namespace Org.BouncyCastle.Crypto.Tls private volatile bool mClosed = false; private volatile bool mFailed = false; - private volatile ProtocolVersion mDiscoveredPeerVersion = null; + private volatile ProtocolVersion mReadVersion = null, mWriteVersion = null; private volatile bool mInHandshake; private volatile int mPlaintextLimit; private DtlsEpoch mCurrentEpoch, mPendingEpoch; @@ -52,16 +52,15 @@ namespace Org.BouncyCastle.Crypto.Tls this.mPlaintextLimit = plaintextLimit; } - internal virtual ProtocolVersion DiscoveredPeerVersion + internal virtual ProtocolVersion ReadVersion { - get { return mDiscoveredPeerVersion; } + get { return mReadVersion; } + set { this.mReadVersion = value; } } - internal virtual ProtocolVersion ResetDiscoveredPeerVersion() + internal virtual void SetWriteVersion(ProtocolVersion writeVersion) { - ProtocolVersion result = mDiscoveredPeerVersion; - mDiscoveredPeerVersion = null; - return result; + this.mWriteVersion = writeVersion; } internal virtual void InitPendingEpoch(TlsCipher pendingCipher) @@ -199,7 +198,12 @@ namespace Org.BouncyCastle.Crypto.Tls } ProtocolVersion version = TlsUtilities.ReadVersion(record, 1); - if (mDiscoveredPeerVersion != null && !mDiscoveredPeerVersion.Equals(version)) + if (!version.IsDtls) + { + continue; + } + + if (mReadVersion != null && !mReadVersion.Equals(version)) { continue; } @@ -215,9 +219,9 @@ namespace Org.BouncyCastle.Crypto.Tls continue; } - if (mDiscoveredPeerVersion == null) + if (mReadVersion == null) { - mDiscoveredPeerVersion = version; + mReadVersion = version; } switch (type) @@ -469,6 +473,10 @@ namespace Org.BouncyCastle.Crypto.Tls private void SendRecord(byte contentType, byte[] buf, int off, int len) { + // Never send anything until a valid ClientHello has been received + if (mWriteVersion == null) + return; + if (len > this.mPlaintextLimit) throw new TlsFatalAlert(AlertDescription.internal_error); @@ -489,7 +497,7 @@ namespace Org.BouncyCastle.Crypto.Tls byte[] record = new byte[ciphertext.Length + RECORD_HEADER_LENGTH]; TlsUtilities.WriteUint8(contentType, record, 0); - ProtocolVersion version = mDiscoveredPeerVersion != null ? mDiscoveredPeerVersion : mContext.ClientVersion; + ProtocolVersion version = mWriteVersion; TlsUtilities.WriteVersion(version, record, 1); TlsUtilities.WriteUint16(recordEpoch, record, 3); TlsUtilities.WriteUint48(recordSequenceNumber, record, 5); diff --git a/crypto/src/crypto/tls/DtlsServerProtocol.cs b/crypto/src/crypto/tls/DtlsServerProtocol.cs index c556d6320..171984b6f 100644 --- a/crypto/src/crypto/tls/DtlsServerProtocol.cs +++ b/crypto/src/crypto/tls/DtlsServerProtocol.cs @@ -76,12 +76,8 @@ namespace Org.BouncyCastle.Crypto.Tls DtlsReliableHandshake.Message clientMessage = handshake.ReceiveMessage(); - { - // NOTE: After receiving a record from the client, we discover the record layer version - ProtocolVersion client_version = recordLayer.DiscoveredPeerVersion; - // TODO Read RFCs for guidance on the expected record layer version number - state.serverContext.SetClientVersion(client_version); - } + // NOTE: DTLSRecordLayer requires any DTLS version, we don't otherwise constrain this + //ProtocolVersion recordLayerVersion = recordLayer.ReadVersion; if (clientMessage.Type == HandshakeType.client_hello) { @@ -97,6 +93,10 @@ namespace Org.BouncyCastle.Crypto.Tls ApplyMaxFragmentLengthExtension(recordLayer, securityParameters.maxFragmentLength); + ProtocolVersion recordLayerVersion = state.serverContext.ServerVersion; + recordLayer.ReadVersion = recordLayerVersion; + recordLayer.SetWriteVersion(recordLayerVersion); + handshake.SendMessage(HandshakeType.server_hello, serverHelloBody); } |