using System;
using System.Diagnostics;
using System.IO;
using System.Runtime.ExceptionServices;
using Org.BouncyCastle.Tls.Crypto;
using Org.BouncyCastle.Tls.Crypto.Impl;
using Org.BouncyCastle.Utilities;
namespace Org.BouncyCastle.Tls
{
/// <summary>An implementation of the TLS 1.0/1.1/1.2 record layer.</summary>
internal sealed class RecordStream
{
// Initial value of both the plaintext and the ciphertext length limits: 2^14 bytes.
private const int DefaultPlaintextLimit = (1 << 14);
// Buffer that ReadRecord() fills with the header and fragment of the record currently being read.
private readonly Record m_inputRecord = new Record();
// Per-direction record counters passed to the cipher on decode/encode. Each one is reset whenever the
// cipher for its direction is replaced or rekeyed.
private readonly SequenceNumber m_readSeqNo = new SequenceNumber(), m_writeSeqNo = new SequenceNumber();
// Receives decoded records via ProcessRecord and is asked whether application data may be accepted yet.
private readonly TlsProtocol m_handler;
// Stream that ReadRecord() reads records from; disposed by Close().
private readonly Stream m_input;
// Stream that WriteRecord() writes records to and flushes; disposed by Close().
private readonly Stream m_output;
// Cipher staged by SetPendingCipher(). It is installed by EnablePendingCipherRead/EnablePendingCipherWrite
// and cleared again by FinaliseHandshake().
private TlsCipher m_pendingCipher = null;
// Cipher used to decode incoming records.
private TlsCipher m_readCipher = TlsNullNullCipher.Instance;
// Read cipher whose installation was deferred; CheckRecordType() installs it when the next
// application_data record header is seen.
private TlsCipher m_readCipherDeferred = null;
// Cipher used to encode outgoing records.
private TlsCipher m_writeCipher = TlsNullNullCipher.Instance;
// Version written into outgoing record headers. While this is null, WriteRecord() sends nothing.
private ProtocolVersion m_writeVersion = null;
// Maximum plaintext length allowed per record, checked on both decoded input and outgoing fragments.
private int m_plaintextLimit = DefaultPlaintextLimit;
// Maximum fragment length accepted in an incoming record header; recomputed from the read cipher
// whenever the read cipher or the plaintext limit changes.
private int m_ciphertextLimit = DefaultPlaintextLimit;
// When true, incoming change_cipher_spec records are validated by CheckChangeCipherSpec() and then
// dropped instead of being decoded and passed to the handler.
private bool m_ignoreChangeCipherSpec = false;
/// <summary>Creates a record stream bound to a protocol handler and a pair of transport streams.</summary>
/// <param name="handler">The protocol instance that receives decoded records.</param>
/// <param name="input">The stream records are read from.</param>
/// <param name="output">The stream records are written to.</param>
internal RecordStream(TlsProtocol handler, Stream input, Stream output)
{
    m_handler = handler;
    m_input = input;
    m_output = output;
}
/// <summary>Gets the current maximum plaintext length per record.</summary>
internal int PlaintextLimit => m_plaintextLimit;
/// <summary>
/// Sets the maximum plaintext length per record and recomputes the matching ciphertext limit from the
/// current read cipher.
/// </summary>
/// <param name="plaintextLimit">The new plaintext limit, in bytes.</param>
internal void SetPlaintextLimit(int plaintextLimit)
{
    m_plaintextLimit = plaintextLimit;

    // The ciphertext limit is always derived from the read cipher for the limit just stored.
    int decodeLimit = m_readCipher.GetCiphertextDecodeLimit(plaintextLimit);
    m_ciphertextLimit = decodeLimit;
}
/// <summary>Sets the protocol version written into outgoing record headers.</summary>
/// <param name="writeVersion">The version to use; while it is null, <c>WriteRecord</c> sends nothing.</param>
internal void SetWriteVersion(ProtocolVersion writeVersion) => m_writeVersion = writeVersion;
/// <summary>
/// Selects whether incoming change_cipher_spec records are validated and dropped rather than decoded.
/// </summary>
/// <param name="ignoreChangeCipherSpec">True to drop well-formed change_cipher_spec records.</param>
internal void SetIgnoreChangeCipherSpec(bool ignoreChangeCipherSpec) =>
    m_ignoreChangeCipherSpec = ignoreChangeCipherSpec;
/// <summary>Stages a cipher that can later be installed for reading and/or writing.</summary>
/// <param name="tlsCipher">The cipher to stage.</param>
internal void SetPendingCipher(TlsCipher tlsCipher) => m_pendingCipher = tlsCipher;
/// <summary>Installs the pending cipher for reading, immediately, after a change_cipher_spec.</summary>
/// <exception cref="TlsFatalAlert">
/// With <c>unexpected_message</c> if no cipher has been staged.
/// </exception>
internal void NotifyChangeCipherSpecReceived()
{
    if (null == m_pendingCipher)
    {
        throw new TlsFatalAlert(AlertDescription.unexpected_message, "No pending cipher");
    }

    EnablePendingCipherRead(deferred: false);
}
/// <summary>
/// Makes the pending cipher the read cipher, either now or when the next application_data record
/// header is seen.
/// </summary>
/// <param name="deferred">
/// True to stage the cipher as the deferred read cipher; false to install it immediately, recompute
/// the ciphertext limit and reset the read sequence number.
/// </param>
/// <exception cref="TlsFatalAlert">
/// With <c>internal_error</c> if no cipher is pending or a deferred read cipher is already staged.
/// </exception>
internal void EnablePendingCipherRead(bool deferred)
{
    // Both precondition failures raise the same alert, so they are checked together.
    bool nothingPending = m_pendingCipher == null;
    if (nothingPending || m_readCipherDeferred != null)
    {
        throw new TlsFatalAlert(AlertDescription.internal_error);
    }

    if (deferred)
    {
        m_readCipherDeferred = m_pendingCipher;
        return;
    }

    m_readCipher = m_pendingCipher;
    m_ciphertextLimit = m_readCipher.GetCiphertextDecodeLimit(m_plaintextLimit);
    m_readSeqNo.Reset();
}
/// <summary>Makes the pending cipher the write cipher and resets the write sequence number.</summary>
/// <exception cref="TlsFatalAlert">With <c>internal_error</c> if no cipher is pending.</exception>
internal void EnablePendingCipherWrite()
{
    TlsCipher pending = m_pendingCipher;
    if (pending == null)
    {
        throw new TlsFatalAlert(AlertDescription.internal_error);
    }

    m_writeCipher = pending;
    m_writeSeqNo.Reset();
}
/// <summary>
/// Verifies that the pending cipher is installed in both directions, then clears the pending slot.
/// </summary>
/// <exception cref="TlsFatalAlert">
/// With <c>handshake_failure</c> if the read or the write cipher is not the pending cipher.
/// </exception>
internal void FinaliseHandshake()
{
    bool readInstalled = m_readCipher == m_pendingCipher;
    bool writeInstalled = m_writeCipher == m_pendingCipher;

    if (!(readInstalled && writeInstalled))
    {
        throw new TlsFatalAlert(AlertDescription.handshake_failure);
    }

    m_pendingCipher = null;
}
/// <summary>Reports whether the write sequence number has reached 2^20.</summary>
/// <returns>True once at least 2^20 records have been numbered since the last write-side reset.</returns>
internal bool NeedsKeyUpdate()
{
    const long keyUpdateThreshold = 1L << 20;

    long written = m_writeSeqNo.CurrentValue;
    return written >= keyUpdateThreshold;
}
/// <summary>Rekeys the read cipher's decoder and restarts the read sequence number at zero.</summary>
internal void NotifyKeyUpdateReceived()
{
    TlsCipher reader = m_readCipher;
    reader.RekeyDecoder();

    // Sequence numbering restarts once the decoder has been rekeyed.
    m_readSeqNo.Reset();
}
/// <summary>Rekeys the write cipher's encoder and restarts the write sequence number at zero.</summary>
internal void NotifyKeyUpdateSent()
{
    TlsCipher writer = m_writeCipher;
    writer.RekeyEncoder();

    // Sequence numbering restarts once the encoder has been rekeyed.
    m_writeSeqNo.Reset();
}
/// <summary>
/// Inspects a record header and reports the full record size together with an upper bound on the
/// application data it may carry.
/// </summary>
/// <param name="recordHeader">Buffer holding the record header, starting at index 0.</param>
/// <returns>
/// A preview whose record size is the header size plus the declared fragment length. The application
/// data limit is zero unless the record type is application_data and the handler is ready for it.
/// </returns>
/// <exception cref="TlsFatalAlert">
/// If the record type is rejected, or with <c>record_overflow</c> if the declared length exceeds the
/// ciphertext limit.
/// </exception>
internal RecordPreview PreviewRecordHeader(byte[] recordHeader)
{
    // Note: validating the type may also install a deferred read cipher as a side effect.
    short type = CheckRecordType(recordHeader, RecordFormat.TypeOffset);

    int fragmentLength = TlsUtilities.ReadUint16(recordHeader, RecordFormat.LengthOffset);
    CheckLength(fragmentLength, m_ciphertextLimit, AlertDescription.record_overflow);

    int totalSize = RecordFormat.FragmentOffset + fragmentLength;

    // NOTE: For TLS 1.3, this only MIGHT be application data
    if (ContentType.application_data != type || !m_handler.IsApplicationDataReady)
        return new RecordPreview(totalSize, 0);

    TlsCipher reader = m_readCipher;
    int decodeLimit = reader is TlsCipherExt extended
        ? extended.GetPlaintextDecodeLimit(fragmentLength)
        : reader.GetPlaintextLimit(fragmentLength);

    int boundedLimit = System.Math.Min(m_plaintextLimit, decodeLimit);
    return new RecordPreview(totalSize, System.Math.Max(0, boundedLimit));
}
/// <summary>
/// Reports how much of <paramref name="contentLength"/> fits in one outgoing record and how large that
/// record would be.
/// </summary>
/// <param name="contentLength">The amount of content the caller would like to send.</param>
/// <returns>
/// A preview holding the record size and the content length clamped to the range 0..plaintext limit.
/// </returns>
internal RecordPreview PreviewOutputRecord(int contentLength)
{
    int clamped = System.Math.Min(m_plaintextLimit, contentLength);
    if (clamped < 0)
    {
        clamped = 0;
    }

    int size = PreviewOutputRecordSize(clamped);
    return new RecordPreview(size, clamped);
}
/// <summary>
/// Computes the size of an outgoing record for the given content length: the header size plus the
/// write cipher's ciphertext encode limit.
/// </summary>
/// <param name="contentLength">Content length; expected not to exceed the plaintext limit.</param>
/// <returns>The record size in bytes, header included.</returns>
internal int PreviewOutputRecordSize(int contentLength)
{
    Debug.Assert(contentLength <= m_plaintextLimit);

    int encodeLimit = m_writeCipher.GetCiphertextEncodeLimit(contentLength, m_plaintextLimit);
    return RecordFormat.FragmentOffset + encodeLimit;
}
/// <summary>
/// Processes a buffer only if it holds exactly one complete record (header plus declared fragment).
/// </summary>
/// <param name="input">Buffer containing the candidate record.</param>
/// <param name="inputOff">Offset of the record header within <paramref name="input"/>.</param>
/// <param name="inputLen">Number of bytes available from <paramref name="inputOff"/>.</param>
/// <returns>
/// False if fewer bytes than a header are available, or the available length is not exactly the header
/// size plus the declared fragment length. True once the record has been handled, either passed to
/// the handler or dropped as an ignored change_cipher_spec.
/// </returns>
/// <exception cref="TlsFatalAlert">If the record type, length or content is rejected.</exception>
internal bool ReadFullRecord(byte[] input, int inputOff, int inputLen)
{
    int headerSize = RecordFormat.FragmentOffset;
    if (inputLen < headerSize)
        return false;

    int fragmentLength = TlsUtilities.ReadUint16(input, inputOff + RecordFormat.LengthOffset);
    if (headerSize + fragmentLength != inputLen)
        return false;

    short type = CheckRecordType(input, inputOff + RecordFormat.TypeOffset);
    ProtocolVersion version = TlsUtilities.ReadVersion(input, inputOff + RecordFormat.VersionOffset);
    CheckLength(fragmentLength, m_ciphertextLimit, AlertDescription.record_overflow);

    int fragmentOff = inputOff + headerSize;

    if (m_ignoreChangeCipherSpec && ContentType.change_cipher_spec == type)
    {
        // Validate the record but do not decode it or hand it to the handler.
        CheckChangeCipherSpec(input, fragmentOff, fragmentLength);
        return true;
    }

    TlsDecodeResult result = DecodeAndVerify(type, version, input, fragmentOff, fragmentLength);
    m_handler.ProcessRecord(result.contentType, result.buf, result.off, result.len);
    return true;
}
/// <summary>Reads one record from the input stream, decodes it and passes it to the handler.</summary>
/// <returns>
/// False if the stream ended before any header byte was read. True once a record has been handled,
/// either passed to the handler or dropped as an ignored change_cipher_spec.
/// </returns>
/// <exception cref="EndOfStreamException">If the stream ends part-way through a record.</exception>
/// <exception cref="TlsFatalAlert">If the record type, length or content is rejected.</exception>
internal bool ReadRecord()
{
    Record record = m_inputRecord;

    if (!record.ReadHeader(m_input))
        return false;

    short type = CheckRecordType(record.m_buf, RecordFormat.TypeOffset);
    ProtocolVersion version = TlsUtilities.ReadVersion(record.m_buf, RecordFormat.VersionOffset);
    int fragmentLength = TlsUtilities.ReadUint16(record.m_buf, RecordFormat.LengthOffset);

    CheckLength(fragmentLength, m_ciphertextLimit, AlertDescription.record_overflow);

    // Reading the fragment may replace the record's buffer, so m_buf is re-read below.
    record.ReadFragment(m_input, fragmentLength);

    TlsDecodeResult result;
    try
    {
        if (m_ignoreChangeCipherSpec && ContentType.change_cipher_spec == type)
        {
            CheckChangeCipherSpec(record.m_buf, RecordFormat.FragmentOffset, fragmentLength);
            return true;
        }

        result = DecodeAndVerify(type, version, record.m_buf, RecordFormat.FragmentOffset, fragmentLength);
    }
    finally
    {
        // The input buffer is released whether decoding succeeded, failed or was skipped.
        record.Reset();
    }

    m_handler.ProcessRecord(result.contentType, result.buf, result.off, result.len);
    return true;
}
/// <summary>
/// Decodes a record fragment with the current read cipher and validates the resulting plaintext length.
/// </summary>
/// <param name="recordType">Record type taken from the record header.</param>
/// <param name="recordVersion">Protocol version taken from the record header.</param>
/// <param name="ciphertext">Buffer containing the fragment.</param>
/// <param name="off">Offset of the fragment within <paramref name="ciphertext"/>.</param>
/// <param name="len">Length of the fragment.</param>
/// <returns>The decode result produced by the read cipher.</returns>
/// <exception cref="TlsFatalAlert">
/// With <c>unexpected_message</c> if read sequence numbers are exhausted, <c>record_overflow</c> if the
/// plaintext exceeds the plaintext limit, or <c>illegal_parameter</c> if a record other than
/// application_data decodes to an empty fragment.
/// </exception>
internal TlsDecodeResult DecodeAndVerify(short recordType, ProtocolVersion recordVersion, byte[] ciphertext,
    int off, int len)
{
    long sequenceNumber = m_readSeqNo.NextValue(AlertDescription.unexpected_message);

    TlsDecodeResult result = m_readCipher.DecodeCiphertext(sequenceNumber, recordType, recordVersion,
        ciphertext, off, len);

    CheckLength(result.len, m_plaintextLimit, AlertDescription.record_overflow);

    // RFC 5246 6.2.1: zero-length fragments of Handshake, Alert, or ChangeCipherSpec content types
    // MUST NOT be sent, so only application_data may decode to nothing.
    bool isEmpty = result.len < 1;
    if (isEmpty && ContentType.application_data != result.contentType)
    {
        throw new TlsFatalAlert(AlertDescription.illegal_parameter);
    }

    return result;
}
/// <summary>
/// Encodes one plaintext fragment with the current write cipher and writes it to the output stream as
/// a single record, then flushes the stream.
/// </summary>
/// <param name="contentType">Content type of the fragment.</param>
/// <param name="plaintext">Buffer containing the fragment.</param>
/// <param name="plaintextOffset">Offset of the fragment within <paramref name="plaintext"/>.</param>
/// <param name="plaintextLength">Length of the fragment.</param>
/// <exception cref="TlsFatalAlert">
/// With <c>internal_error</c> if the fragment exceeds the plaintext limit, if a fragment other than
/// application_data is empty, or if write sequence numbers are exhausted.
/// </exception>
internal void WriteRecord(short contentType, byte[] plaintext, int plaintextOffset, int plaintextLength)
{
#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
    // Span-capable targets delegate to the span-based overload.
    WriteRecord(contentType, plaintext.AsSpan(plaintextOffset, plaintextLength));
#else
    // Never send anything until a valid ClientHello has been received
    ProtocolVersion version = m_writeVersion;
    if (version == null)
        return;

    // RFC 5246 6.2.1 The length should not exceed 2^14.
    CheckLength(plaintextLength, m_plaintextLimit, AlertDescription.internal_error);

    // RFC 5246 6.2.1 Implementations MUST NOT send zero-length fragments of Handshake, Alert,
    // or ChangeCipherSpec content types.
    bool isEmpty = plaintextLength < 1;
    if (isEmpty && ContentType.application_data != contentType)
    {
        throw new TlsFatalAlert(AlertDescription.internal_error);
    }

    long sequenceNumber = m_writeSeqNo.NextValue(AlertDescription.internal_error);

    // The cipher is asked to leave room for the record header in front of the ciphertext.
    TlsEncodeResult record = m_writeCipher.EncodePlaintext(sequenceNumber, contentType, version,
        RecordFormat.FragmentOffset, plaintext, plaintextOffset, plaintextLength);

    int fragmentLength = record.len - RecordFormat.FragmentOffset;
    TlsUtilities.CheckUint16(fragmentLength);

    byte[] buffer = record.buf;
    int headerOff = record.off;
    TlsUtilities.WriteUint8(record.recordType, buffer, headerOff + RecordFormat.TypeOffset);
    TlsUtilities.WriteVersion(version, buffer, headerOff + RecordFormat.VersionOffset);
    TlsUtilities.WriteUint16(fragmentLength, buffer, headerOff + RecordFormat.LengthOffset);

    // TODO[tls-port] Can we support interrupted IO on .NET?
    m_output.Write(buffer, headerOff, record.len);
    m_output.Flush();
#endif
}
#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
/// <summary>
/// Encodes one plaintext fragment with the current write cipher and writes it to the output stream as
/// a single record, then flushes the stream.
/// </summary>
/// <param name="contentType">Content type of the fragment.</param>
/// <param name="plaintext">The fragment bytes to send.</param>
/// <exception cref="TlsFatalAlert">
/// With <c>internal_error</c> if the fragment exceeds the plaintext limit, if a fragment other than
/// application_data is empty, or if write sequence numbers are exhausted.
/// </exception>
internal void WriteRecord(short contentType, ReadOnlySpan<byte> plaintext)
{
    // Never send anything until a valid ClientHello has been received
    if (m_writeVersion == null)
        return;

    /*
     * RFC 5246 6.2.1 The length should not exceed 2^14.
     */
    CheckLength(plaintext.Length, m_plaintextLimit, AlertDescription.internal_error);

    /*
     * RFC 5246 6.2.1 Implementations MUST NOT send zero-length fragments of Handshake, Alert,
     * or ChangeCipherSpec content types.
     */
    if (plaintext.Length < 1 && contentType != ContentType.application_data)
        throw new TlsFatalAlert(AlertDescription.internal_error);

    long seqNo = m_writeSeqNo.NextValue(AlertDescription.internal_error);
    ProtocolVersion recordVersion = m_writeVersion;

    // The cipher is asked to leave room for the record header in front of the ciphertext.
    TlsEncodeResult encoded = m_writeCipher.EncodePlaintext(seqNo, contentType, recordVersion,
        RecordFormat.FragmentOffset, plaintext);

    int ciphertextLength = encoded.len - RecordFormat.FragmentOffset;
    TlsUtilities.CheckUint16(ciphertextLength);

    // Fill in the header: type, version, then the length of the encoded fragment.
    TlsUtilities.WriteUint8(encoded.recordType, encoded.buf, encoded.off + RecordFormat.TypeOffset);
    TlsUtilities.WriteVersion(recordVersion, encoded.buf, encoded.off + RecordFormat.VersionOffset);
    TlsUtilities.WriteUint16(ciphertextLength, encoded.buf, encoded.off + RecordFormat.LengthOffset);

    // TODO[tls-port] Can we support interrupted IO on .NET?
    m_output.Write(encoded.buf, encoded.off, encoded.len);
    m_output.Flush();
}
#endif
/// <summary>Resets the input record buffer and disposes both the input and the output stream.</summary>
/// <remarks>
/// The output stream is disposed even if disposing the input stream throws. If both disposals throw an
/// <see cref="IOException"/>, the first one is rethrown and the second is discarded.
/// </remarks>
/// <exception cref="IOException">If disposing either stream failed with an I/O error.</exception>
internal void Close()
{
    m_inputRecord.Reset();

    ExceptionDispatchInfo io = null;
    try
    {
        try
        {
            m_input.Dispose();
        }
        catch (IOException e)
        {
            io = ExceptionDispatchInfo.Capture(e);
        }
    }
    finally
    {
        // Runs even when m_input.Dispose() throws something other than IOException, so the output
        // stream is never left undisposed.
        try
        {
            m_output.Dispose();
        }
        catch (IOException e)
        {
            // Only the first I/O failure is reported; a second one is intentionally dropped.
            if (io == null)
            {
                io = ExceptionDispatchInfo.Capture(e);
            }
        }
    }

    // Rethrow the captured failure with its original stack trace.
    io?.Throw();
}
/// <summary>Validates that a fragment is a well-formed change_cipher_spec message.</summary>
/// <param name="buf">Buffer containing the fragment.</param>
/// <param name="off">Offset of the fragment within <paramref name="buf"/>.</param>
/// <param name="len">Length of the fragment.</param>
/// <exception cref="TlsFatalAlert">
/// With <c>unexpected_message</c> unless the fragment is exactly one byte with the change_cipher_spec value.
/// </exception>
private void CheckChangeCipherSpec(byte[] buf, int off, int len)
{
    // The byte is only inspected when the length is right, so a wrong length never indexes the buffer.
    bool wellFormed = len == 1 && buf[off] == (byte)ChangeCipherSpec.change_cipher_spec;
    if (wellFormed)
        return;

    throw new TlsFatalAlert(AlertDescription.unexpected_message,
        "Malformed " + ContentType.GetText(ContentType.change_cipher_spec));
}
/// <summary>
/// Reads the record type from a header and checks that it is acceptable in the current state.
/// </summary>
/// <remarks>
/// If a deferred read cipher is staged and the type is application_data, the deferred cipher is
/// installed here (ciphertext limit recomputed, read sequence number reset) and no further check is made.
/// </remarks>
/// <param name="buf">Buffer containing the record header.</param>
/// <param name="off">Offset of the type byte within <paramref name="buf"/>.</param>
/// <returns>The record type that was read.</returns>
/// <exception cref="TlsFatalAlert">With <c>unexpected_message</c> if the type is not acceptable.</exception>
private short CheckRecordType(byte[] buf, int off)
{
    short type = TlsUtilities.ReadUint8(buf, off);
    bool isApplicationData = ContentType.application_data == type;

    // Case 1: a deferred read cipher takes effect on the first application_data record.
    if (m_readCipherDeferred != null && isApplicationData)
    {
        m_readCipher = m_readCipherDeferred;
        m_readCipherDeferred = null;
        m_ciphertextLimit = m_readCipher.GetCiphertextDecodeLimit(m_plaintextLimit);
        m_readSeqNo.Reset();
        return type;
    }

    // Case 2: the read cipher uses an opaque record type, so only application_data is expected,
    // plus a change_cipher_spec when those are being ignored (see RFC 8446 D.4).
    if (m_readCipher.UsesOpaqueRecordType)
    {
        if (!isApplicationData)
        {
            bool toleratedChangeCipherSpec =
                m_ignoreChangeCipherSpec && ContentType.change_cipher_spec == type;

            if (!toleratedChangeCipherSpec)
            {
                throw new TlsFatalAlert(AlertDescription.unexpected_message,
                    "Opaque " + ContentType.GetText(type));
            }
        }
        return type;
    }

    // Case 3: the record type is taken at face value.
    if (isApplicationData)
    {
        if (!m_handler.IsApplicationDataReady)
        {
            throw new TlsFatalAlert(AlertDescription.unexpected_message,
                "Not ready for " + ContentType.GetText(ContentType.application_data));
        }
    }
    else if (ContentType.alert != type
        && ContentType.change_cipher_spec != type
        && ContentType.handshake != type)
    {
        // Heartbeat and any other content type are not supported.
        throw new TlsFatalAlert(AlertDescription.unexpected_message,
            "Unsupported " + ContentType.GetText(type));
    }

    return type;
}
/// <summary>Raises the given alert when a length exceeds its limit.</summary>
/// <param name="length">The length to check.</param>
/// <param name="limit">The largest acceptable length (inclusive).</param>
/// <param name="alertDescription">Alert description to raise on failure.</param>
/// <exception cref="TlsFatalAlert">If <paramref name="length"/> is greater than <paramref name="limit"/>.</exception>
private static void CheckLength(int length, int limit, short alertDescription)
{
    if (length <= limit)
        return;

    throw new TlsFatalAlert(alertDescription);
}
/// <summary>
/// Accumulates the bytes of one incoming record: first the fixed-size header, then the fragment.
/// </summary>
private sealed class Record
{
    // Fixed-size header buffer; also serves as the working buffer until a fragment needs more room.
    private readonly byte[] m_header = new byte[RecordFormat.FragmentOffset];

    // Current working buffer (the header buffer or a larger replacement) and the count of bytes filled.
    internal volatile byte[] m_buf;
    internal volatile int m_pos;

    internal Record()
    {
        Reset();
    }

    /// <summary>
    /// Reads from <paramref name="input"/> until <paramref name="length"/> bytes are buffered or the
    /// stream stops returning data.
    /// </summary>
    internal void FillTo(Stream input, int length)
    {
        // TODO[tls-port] Can we support interrupted IO on .NET?
        while (m_pos < length)
        {
            int count = input.Read(m_buf, m_pos, length - m_pos);
            if (count <= 0)
                return;

            m_pos += count;
        }
    }

    /// <summary>Reads the fragment that follows an already-buffered header.</summary>
    /// <exception cref="EndOfStreamException">If the stream ends before the whole record is read.</exception>
    internal void ReadFragment(Stream input, int fragmentLength)
    {
        int total = RecordFormat.FragmentOffset + fragmentLength;

        Resize(total);
        FillTo(input, total);

        if (m_pos < total)
        {
            throw new EndOfStreamException();
        }
    }

    /// <summary>Reads a record header.</summary>
    /// <returns>False if no bytes at all were buffered; true if a full header is buffered.</returns>
    /// <exception cref="EndOfStreamException">If the stream ends part-way through the header.</exception>
    internal bool ReadHeader(Stream input)
    {
        FillTo(input, RecordFormat.FragmentOffset);

        int buffered = m_pos;
        if (buffered == 0)
            return false;

        if (buffered < RecordFormat.FragmentOffset)
        {
            throw new EndOfStreamException();
        }

        return true;
    }

    /// <summary>Returns to the empty state, pointing the working buffer back at the header buffer.</summary>
    internal void Reset()
    {
        m_buf = m_header;
        m_pos = 0;
    }

    // Swaps in a larger working buffer when needed, carrying over the bytes already read.
    private void Resize(int length)
    {
        if (m_buf.Length >= length)
            return;

        byte[] larger = new byte[length];
        Array.Copy(m_buf, larger, m_pos);
        m_buf = larger;
    }
}
/// <summary>
/// A 64-bit record sequence counter that refuses to hand out further values once it has wrapped
/// around to zero.
/// </summary>
private sealed class SequenceNumber
{
    // Dedicated lock object, so synchronisation never depends on the instance reference itself.
    private readonly object m_lock = new object();

    private long m_value = 0L;
    private bool m_exhausted = false;

    /// <summary>Gets the value the next call to <see cref="NextValue"/> would return.</summary>
    internal long CurrentValue
    {
        get
        {
            lock (m_lock)
            {
                return m_value;
            }
        }
    }

    /// <summary>Returns the current value and advances the counter.</summary>
    /// <param name="alertDescription">Alert description to raise if the counter is exhausted.</param>
    /// <returns>The sequence number to use for the next record.</returns>
    /// <exception cref="TlsFatalAlert">If the counter has already wrapped around to zero.</exception>
    internal long NextValue(short alertDescription)
    {
        lock (m_lock)
        {
            if (m_exhausted)
                throw new TlsFatalAlert(alertDescription, "Sequence numbers exhausted");

            long result = m_value;

            // The counter is treated as 64 raw bits: it is meant to wrap silently, and only the wrap
            // back to zero marks exhaustion. 'unchecked' keeps that true under a checked build.
            if (unchecked(++m_value) == 0L)
            {
                m_exhausted = true;
            }

            return result;
        }
    }

    /// <summary>Restarts the counter at zero and clears the exhausted state.</summary>
    internal void Reset()
    {
        lock (m_lock)
        {
            m_value = 0L;
            m_exhausted = false;
        }
    }
}
}
}