diff options
Diffstat (limited to 'crypto/src/tls/DtlsRecordLayer.cs')
-rw-r--r-- | crypto/src/tls/DtlsRecordLayer.cs | 51 |
1 files changed, 39 insertions, 12 deletions
diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs index a786da127..82fc3db64 100644 --- a/crypto/src/tls/DtlsRecordLayer.cs +++ b/crypto/src/tls/DtlsRecordLayer.cs @@ -274,8 +274,14 @@ namespace Org.BouncyCastle.Tls /// <exception cref="IOException"/> public virtual int Receive(byte[] buf, int off, int len, int waitMillis) { + return Receive(buf, off, len, waitMillis, null); + } + + /// <exception cref="IOException"/> + internal int Receive(byte[] buf, int off, int len, int waitMillis, Action recordCallback) + { #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER - return Receive(buf.AsSpan(off, len), waitMillis); + return Receive(buf.AsSpan(off, len), waitMillis, recordCallback); #else long currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); @@ -329,7 +335,7 @@ namespace Org.BouncyCastle.Tls } int received = ReceiveRecord(record, 0, receiveLimit, waitMillis); - int processed = ProcessRecord(received, record, buf, off, len); + int processed = ProcessRecord(received, record, buf, off, len, recordCallback); if (processed >= 0) return processed; @@ -342,10 +348,10 @@ namespace Org.BouncyCastle.Tls } /// <exception cref="IOException"/> - internal int ReceivePending(byte[] buf, int off, int len) + internal int ReceivePending(byte[] buf, int off, int len, Action recordCallback) { #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER - return ReceivePending(buf.AsSpan(off, len)); + return ReceivePending(buf.AsSpan(off, len), recordCallback); #else if (m_recordQueue.Available > 0) { @@ -355,7 +361,7 @@ namespace Org.BouncyCastle.Tls do { int received = ReceivePendingRecord(record, 0, receiveLimit); - int processed = ProcessRecord(received, record, buf, off, len); + int processed = ProcessRecord(received, record, buf, off, len, recordCallback); if (processed >= 0) return processed; } @@ -370,6 +376,12 @@ namespace Org.BouncyCastle.Tls /// <exception cref="IOException"/> public virtual int Receive(Span<byte> buffer, int waitMillis) { + return Receive(buffer, waitMillis, null); + } + + /// <exception cref="IOException"/> + internal int Receive(Span<byte> buffer, int waitMillis, Action recordCallback) + { long currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); Timeout timeout = Timeout.ForWaitMillis(waitMillis, currentTimeMillis); @@ -422,7 +434,7 @@ namespace Org.BouncyCastle.Tls } int received = ReceiveRecord(record, 0, receiveLimit, waitMillis); - int processed = ProcessRecord(received, record, buffer); + int processed = ProcessRecord(received, record, buffer, recordCallback); if (processed >= 0) return processed; @@ -434,7 +446,7 @@ namespace Org.BouncyCastle.Tls } /// <exception cref="IOException"/> - internal int ReceivePending(Span<byte> buffer) + internal int ReceivePending(Span<byte> buffer, Action recordCallback) { if (m_recordQueue.Available > 0) { @@ -444,7 +456,7 @@ namespace Org.BouncyCastle.Tls do { int received = ReceivePendingRecord(record, 0, receiveLimit); - int processed = ProcessRecord(received, record, buffer); + int processed = ProcessRecord(received, record, buffer, recordCallback); if (processed >= 0) return processed; } @@ -665,9 +677,9 @@ namespace Org.BouncyCastle.Tls // TODO Include 'currentTimeMillis' as an argument, use with Timeout, resetHeartbeat /// <exception cref="IOException"/> #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER - private int ProcessRecord(int received, byte[] record, Span<byte> buffer) + private int ProcessRecord(int received, byte[] record, Span<byte> buffer, Action recordCallback) #else - private int ProcessRecord(int received, byte[] record, byte[] buf, int off, int len) + private int ProcessRecord(int received, byte[] record, byte[] buf, int off, int len, Action recordCallback) #endif { // NOTE: received < 0 (timeout) is covered by this first case @@ -772,8 +784,6 @@ namespace Org.BouncyCastle.Tls return -1; } - recordEpoch.ReplayWindow.ReportAuthenticated(seq); - if (decoded.len > m_plaintextLimit) return -1; @@ -805,6 +815,23 @@ namespace Org.BouncyCastle.Tls } } + recordEpoch.ReplayWindow.ReportAuthenticated(seq, out var isLatestConfirmed); + + /* + * NOTE: The record has passed record layer validation and will be dispatched according to the decoded + * content type. + */ + if (recordCallback != null) + { + // TODO Make the callback more general than just peer address update + if (ContentType.tls12_cid == recordType && + isLatestConfirmed && + recordEpoch == m_readEpoch) + { + recordCallback(); + } + } + switch (decoded.contentType) { case ContentType.alert: |