diff --git a/crypto/src/tls/DeferredHash.cs b/crypto/src/tls/DeferredHash.cs
index bba3019a1..f97e9c088 100644
--- a/crypto/src/tls/DeferredHash.cs
+++ b/crypto/src/tls/DeferredHash.cs
@@ -16,7 +16,7 @@ namespace Org.BouncyCastle.Tls
private readonly TlsContext m_context;
private DigestInputBuffer m_buf;
- private readonly IDictionary m_hashes;
+ private IDictionary m_hashes;
private bool m_forceBuffering;
private bool m_sealed;
@@ -29,21 +29,12 @@ namespace Org.BouncyCastle.Tls
this.m_sealed = false;
}
- private DeferredHash(TlsContext context, IDictionary hashes)
- {
- this.m_context = context;
- this.m_buf = null;
- this.m_hashes = hashes;
- this.m_forceBuffering = false;
- this.m_sealed = true;
- }
-
/// <exception cref="IOException"/>
public void CopyBufferTo(Stream output)
{
if (m_buf == null)
{
- // If you see this, you need to call forceBuffering() before SealHashAlgorithms()
+ // If you see this, you need to call ForceBuffering() before SealHashAlgorithms()
throw new InvalidOperationException("Not buffering");
}
@@ -96,7 +87,7 @@ namespace Org.BouncyCastle.Tls
CheckStopBuffering();
}
- public TlsHandshakeHash StopTracking()
+ public void StopTracking()
{
SecurityParameters securityParameters = m_context.SecurityParameters;
@@ -116,7 +107,11 @@ namespace Org.BouncyCastle.Tls
break;
}
}
- return new DeferredHash(m_context, newHashes);
+
+ this.m_buf = null;
+ this.m_hashes = newHashes;
+ this.m_forceBuffering = false;
+ this.m_sealed = true;
}
public TlsHash ForkPrfHash()
diff --git a/crypto/src/tls/DtlsReliableHandshake.cs b/crypto/src/tls/DtlsReliableHandshake.cs
index e27d72762..58b9301fd 100644
--- a/crypto/src/tls/DtlsReliableHandshake.cs
+++ b/crypto/src/tls/DtlsReliableHandshake.cs
@@ -142,11 +142,9 @@ namespace Org.BouncyCastle.Tls
get { return m_handshakeHash; }
}
- internal TlsHandshakeHash PrepareToFinish()
+ internal void PrepareToFinish()
{
- TlsHandshakeHash result = m_handshakeHash;
- this.m_handshakeHash = m_handshakeHash.StopTracking();
- return result;
+ m_handshakeHash.StopTracking();
}
/// <exception cref="IOException"/>
@@ -173,69 +171,63 @@ namespace Org.BouncyCastle.Tls
}
/// <exception cref="IOException"/>
+ internal Message ReceiveMessage()
+ {
+ Message message = ImplReceiveMessage();
+ UpdateHandshakeMessagesDigest(message);
+ return message;
+ }
+
+ /// <exception cref="IOException"/>
internal byte[] ReceiveMessageBody(short msg_type)
{
- Message message = ReceiveMessage();
+ Message message = ImplReceiveMessage();
if (message.Type != msg_type)
throw new TlsFatalAlert(AlertDescription.unexpected_message);
+ UpdateHandshakeMessagesDigest(message);
return message.Body;
}
/// <exception cref="IOException"/>
- internal Message ReceiveMessage()
+ internal Message ReceiveMessageDelayedDigest(short msg_type)
{
- long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
-
- if (null == m_resendTimeout)
- {
- m_resendMillis = INITIAL_RESEND_MILLIS;
- m_resendTimeout = new Timeout(m_resendMillis, currentTimeMillis);
-
- PrepareInboundFlight(Platform.CreateHashtable());
- }
+ Message message = ImplReceiveMessage();
+ if (message.Type != msg_type)
+ throw new TlsFatalAlert(AlertDescription.unexpected_message);
- byte[] buf = null;
+ return message;
+ }
- for (;;)
+ /// <exception cref="IOException"/>
+ internal Message UpdateHandshakeMessagesDigest(Message message)
+ {
+ short msg_type = message.Type;
+ switch (msg_type)
{
- if (m_recordLayer.IsClosed)
- throw new TlsFatalAlert(AlertDescription.user_canceled);
-
- Message pending = GetPendingMessage();
- if (pending != null)
- return pending;
-
- if (Timeout.HasExpired(m_handshakeTimeout, currentTimeMillis))
- throw new TlsTimeoutException("Handshake timed out");
-
- int waitMillis = Timeout.GetWaitMillis(m_handshakeTimeout, currentTimeMillis);
- waitMillis = Timeout.ConstrainWaitMillis(waitMillis, m_resendTimeout, currentTimeMillis);
-
- // NOTE: Ensure a finite wait, of at least 1ms
- if (waitMillis < 1)
- {
- waitMillis = 1;
- }
-
- int receiveLimit = m_recordLayer.GetReceiveLimit();
- if (buf == null || buf.Length < receiveLimit)
- {
- buf = new byte[receiveLimit];
- }
-
- int received = m_recordLayer.Receive(buf, 0, receiveLimit, waitMillis);
- if (received < 0)
- {
- ResendOutboundFlight();
- }
- else
- {
- ProcessRecord(MAX_RECEIVE_AHEAD, m_recordLayer.ReadEpoch, buf, 0, received);
- }
+ case HandshakeType.hello_request:
+ case HandshakeType.hello_verify_request:
+ case HandshakeType.key_update:
+ break;
- currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
+ // TODO[dtls13] Not included in the transcript for (D)TLS 1.3+
+ case HandshakeType.new_session_ticket:
+ default:
+ {
+ byte[] body = message.Body;
+ byte[] buf = new byte[MESSAGE_HEADER_LENGTH];
+ TlsUtilities.WriteUint8(msg_type, buf, 0);
+ TlsUtilities.WriteUint24(body.Length, buf, 1);
+ TlsUtilities.WriteUint16(message.Seq, buf, 4);
+ TlsUtilities.WriteUint24(0, buf, 6);
+ TlsUtilities.WriteUint24(body.Length, buf, 9);
+ m_handshakeHash.Update(buf, 0, buf.Length);
+ m_handshakeHash.Update(body, 0, body.Length);
+ break;
+ }
}
+
+ return message;
}
internal void Finish()
@@ -297,12 +289,68 @@ namespace Org.BouncyCastle.Tls
if (body != null)
{
m_previousInboundFlight = null;
- return UpdateHandshakeMessagesDigest(new Message(m_next_receive_seq++, next.MsgType, body));
+ return new Message(m_next_receive_seq++, next.MsgType, body);
}
}
return null;
}
+ /// <exception cref="IOException"/>
+ private Message ImplReceiveMessage()
+ {
+ long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
+
+ if (null == m_resendTimeout)
+ {
+ m_resendMillis = INITIAL_RESEND_MILLIS;
+ m_resendTimeout = new Timeout(m_resendMillis, currentTimeMillis);
+
+ PrepareInboundFlight(Platform.CreateHashtable());
+ }
+
+ byte[] buf = null;
+
+ for (; ; )
+ {
+ if (m_recordLayer.IsClosed)
+ throw new TlsFatalAlert(AlertDescription.user_canceled);
+
+ Message pending = GetPendingMessage();
+ if (pending != null)
+ return pending;
+
+ if (Timeout.HasExpired(m_handshakeTimeout, currentTimeMillis))
+ throw new TlsTimeoutException("Handshake timed out");
+
+ int waitMillis = Timeout.GetWaitMillis(m_handshakeTimeout, currentTimeMillis);
+ waitMillis = Timeout.ConstrainWaitMillis(waitMillis, m_resendTimeout, currentTimeMillis);
+
+ // NOTE: Ensure a finite wait, of at least 1ms
+ if (waitMillis < 1)
+ {
+ waitMillis = 1;
+ }
+
+ int receiveLimit = m_recordLayer.GetReceiveLimit();
+ if (buf == null || buf.Length < receiveLimit)
+ {
+ buf = new byte[receiveLimit];
+ }
+
+ int received = m_recordLayer.Receive(buf, 0, receiveLimit, waitMillis);
+ if (received < 0)
+ {
+ ResendOutboundFlight();
+ }
+ else
+ {
+ ProcessRecord(MAX_RECEIVE_AHEAD, m_recordLayer.ReadEpoch, buf, 0, received);
+ }
+
+ currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
+ }
+ }
+
private void PrepareInboundFlight(IDictionary nextFlight)
{
ResetAll(m_currentInboundFlight);
@@ -400,37 +448,6 @@ namespace Org.BouncyCastle.Tls
}
/// <exception cref="IOException"/>
- private Message UpdateHandshakeMessagesDigest(Message message)
- {
- short msg_type = message.Type;
- switch (msg_type)
- {
- case HandshakeType.hello_request:
- case HandshakeType.hello_verify_request:
- case HandshakeType.key_update:
- break;
-
- // TODO[dtls13] Not included in the transcript for (D)TLS 1.3+
- case HandshakeType.new_session_ticket:
- default:
- {
- byte[] body = message.Body;
- byte[] buf = new byte[MESSAGE_HEADER_LENGTH];
- TlsUtilities.WriteUint8(msg_type, buf, 0);
- TlsUtilities.WriteUint24(body.Length, buf, 1);
- TlsUtilities.WriteUint16(message.Seq, buf, 4);
- TlsUtilities.WriteUint24(0, buf, 6);
- TlsUtilities.WriteUint24(body.Length, buf, 9);
- m_handshakeHash.Update(buf, 0, buf.Length);
- m_handshakeHash.Update(body, 0, body.Length);
- break;
- }
- }
-
- return message;
- }
-
- /// <exception cref="IOException"/>
private void WriteMessage(Message message)
{
int sendLimit = m_recordLayer.GetSendLimit();
diff --git a/crypto/src/tls/DtlsServerProtocol.cs b/crypto/src/tls/DtlsServerProtocol.cs
index 99c47ba1b..b49122423 100644
--- a/crypto/src/tls/DtlsServerProtocol.cs
+++ b/crypto/src/tls/DtlsServerProtocol.cs
@@ -297,12 +297,17 @@ namespace Org.BouncyCastle.Tls
* parameters).
*/
{
- TlsHandshakeHash certificateVerifyHash = handshake.PrepareToFinish();
-
if (ExpectCertificateVerifyMessage(state))
{
- byte[] certificateVerifyBody = handshake.ReceiveMessageBody(HandshakeType.certificate_verify);
- ProcessCertificateVerify(state, certificateVerifyBody, certificateVerifyHash);
+ clientMessage = handshake.ReceiveMessageDelayedDigest(HandshakeType.certificate_verify);
+ byte[] certificateVerifyBody = clientMessage.Body;
+ ProcessCertificateVerify(state, certificateVerifyBody, handshake.HandshakeHash);
+ handshake.PrepareToFinish();
+ handshake.UpdateHandshakeMessagesDigest(clientMessage);
+ }
+ else
+ {
+ handshake.PrepareToFinish();
}
}
diff --git a/crypto/src/tls/TlsClientProtocol.cs b/crypto/src/tls/TlsClientProtocol.cs
index 19e2eda3d..cb59289ae 100644
--- a/crypto/src/tls/TlsClientProtocol.cs
+++ b/crypto/src/tls/TlsClientProtocol.cs
@@ -609,7 +609,7 @@ namespace Org.BouncyCastle.Tls
this.m_connectionState = CS_CLIENT_CERTIFICATE_VERIFY;
}
- this.m_handshakeHash = m_handshakeHash.StopTracking();
+ m_handshakeHash.StopTracking();
SendChangeCipherSpec();
SendFinishedMessage();
diff --git a/crypto/src/tls/TlsHandshakeHash.cs b/crypto/src/tls/TlsHandshakeHash.cs
index aa33c680d..88aeaaa32 100644
--- a/crypto/src/tls/TlsHandshakeHash.cs
+++ b/crypto/src/tls/TlsHandshakeHash.cs
@@ -20,7 +20,7 @@ namespace Org.BouncyCastle.Tls
void SealHashAlgorithms();
- TlsHandshakeHash StopTracking();
+ void StopTracking();
TlsHash ForkPrfHash();
diff --git a/crypto/src/tls/TlsServerProtocol.cs b/crypto/src/tls/TlsServerProtocol.cs
index 22700a277..0ab8a7a98 100644
--- a/crypto/src/tls/TlsServerProtocol.cs
+++ b/crypto/src/tls/TlsServerProtocol.cs
@@ -1322,7 +1322,7 @@ namespace Org.BouncyCastle.Tls
TlsUtilities.VerifyCertificateVerifyClient(m_tlsServerContext, m_certificateRequest, certificateVerify,
m_handshakeHash);
- this.m_handshakeHash = m_handshakeHash.StopTracking();
+ m_handshakeHash.StopTracking();
}
/// <exception cref="IOException"/>
@@ -1357,7 +1357,7 @@ namespace Org.BouncyCastle.Tls
if (!ExpectCertificateVerifyMessage())
{
- this.m_handshakeHash = m_handshakeHash.StopTracking();
+ m_handshakeHash.StopTracking();
}
}
|