diff options
author | Peter Dettman <peter.dettman@bouncycastle.org> | 2015-10-16 19:18:37 +0700 |
---|---|---|
committer | Peter Dettman <peter.dettman@bouncycastle.org> | 2015-10-16 19:18:37 +0700 |
commit | bba97435f9810d9d8c8e02ff018ecfb87148553a (patch) | |
tree | 6fdd6a3922b0b2d8b2be91f689bf406f0902bca4 | |
parent | Refactoring (diff) | |
download | BouncyCastle.NET-ed25519-bba97435f9810d9d8c8e02ff018ecfb87148553a.tar.xz |
Port of non-blocking TLS API from Java
-rw-r--r-- | crypto/crypto.csproj | 15 | ||||
-rw-r--r-- | crypto/src/crypto/tls/ByteQueueStream.cs | 114 | ||||
-rw-r--r-- | crypto/src/crypto/tls/RecordStream.cs | 25 | ||||
-rw-r--r-- | crypto/src/crypto/tls/TlsClientProtocol.cs | 48 | ||||
-rw-r--r-- | crypto/src/crypto/tls/TlsProtocol.cs | 182 | ||||
-rw-r--r-- | crypto/src/crypto/tls/TlsServerProtocol.cs | 50 | ||||
-rw-r--r-- | crypto/test/src/crypto/tls/test/ByteQueueStreamTest.cs | 134 | ||||
-rw-r--r-- | crypto/test/src/crypto/tls/test/TlsProtocolNonBlockingTest.cs | 126 |
8 files changed, 665 insertions, 29 deletions
diff --git a/crypto/crypto.csproj b/crypto/crypto.csproj index 7f4131f4e..df7df9f5a 100644 --- a/crypto/crypto.csproj +++ b/crypto/crypto.csproj @@ -4469,6 +4469,11 @@ BuildAction = "Compile" /> <File + RelPath = "src\crypto\tls\ByteQueueStream.cs" + SubType = "Code" + BuildAction = "Compile" + /> + <File RelPath = "src\crypto\tls\CertChainType.cs" SubType = "Code" BuildAction = "Compile" @@ -11615,6 +11620,11 @@ BuildAction = "Compile" /> <File + RelPath = "test\src\crypto\tls\test\ByteQueueStreamTest.cs" + SubType = "Code" + BuildAction = "Compile" + /> + <File RelPath = "test\src\crypto\tls\test\DtlsProtocolTest.cs" SubType = "Code" BuildAction = "Compile" @@ -11705,6 +11715,11 @@ BuildAction = "Compile" /> <File + RelPath = "test\src\crypto\tls\test\TlsProtocolNonBlockingTest.cs" + SubType = "Code" + BuildAction = "Compile" + /> + <File RelPath = "test\src\crypto\tls\test\TlsPskProtocolTest.cs" SubType = "Code" BuildAction = "Compile" diff --git a/crypto/src/crypto/tls/ByteQueueStream.cs b/crypto/src/crypto/tls/ByteQueueStream.cs new file mode 100644 index 000000000..bf603e006 --- /dev/null +++ b/crypto/src/crypto/tls/ByteQueueStream.cs @@ -0,0 +1,114 @@ +using System; +using System.IO; + +namespace Org.BouncyCastle.Crypto.Tls +{ + public class ByteQueueStream + : Stream + { + private readonly ByteQueue buffer; + + public ByteQueueStream() + { + this.buffer = new ByteQueue(); + } + + public virtual int Available + { + get { return buffer.Available; } + } + + public override bool CanRead + { + get { return true; } + } + + public override bool CanSeek + { + get { return false; } + } + + public override bool CanWrite + { + get { return true; } + } + + public override void Close() + { + } + + public override void Flush() + { + } + + public override long Length + { + get { throw new NotSupportedException(); } + } + + public virtual int Peek(byte[] buf) + { + int bytesToRead = System.Math.Min(buffer.Available, buf.Length); + buffer.Read(buf, 0, bytesToRead, 0); + return bytesToRead; + } + + public override long Position + { + get { throw new NotSupportedException(); } + set { throw new NotSupportedException(); } + } + + public virtual int Read(byte[] buf) + { + return Read(buf, 0, buf.Length); + } + + public override int Read(byte[] buf, int off, int len) + { + int bytesToRead = System.Math.Min(buffer.Available, len); + buffer.RemoveData(buf, off, bytesToRead, 0); + return bytesToRead; + } + + public override int ReadByte() + { + if (buffer.Available == 0) + return -1; + + return buffer.RemoveData(1, 0)[0] & 0xFF; + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public virtual int Skip(int n) + { + int bytesToSkip = System.Math.Min(buffer.Available, n); + buffer.RemoveData(bytesToSkip); + return bytesToSkip; + } + + public virtual void Write(byte[] buf) + { + buffer.AddData(buf, 0, buf.Length); + } + + public override void Write(byte[] buf, int off, int len) + { + buffer.AddData(buf, off, len); + } + + public override void WriteByte(byte b) + { + buffer.AddData(new byte[]{ b }, 0, 1); + } + } +} diff --git a/crypto/src/crypto/tls/RecordStream.cs b/crypto/src/crypto/tls/RecordStream.cs index db5b158bc..6f3fc41c6 100644 --- a/crypto/src/crypto/tls/RecordStream.cs +++ b/crypto/src/crypto/tls/RecordStream.cs @@ -8,6 +8,11 @@ namespace Org.BouncyCastle.Crypto.Tls { private const int DEFAULT_PLAINTEXT_LIMIT = (1 << 14); + internal const int TLS_HEADER_SIZE = 5; + internal const int TLS_HEADER_TYPE_OFFSET = 0; + internal const int TLS_HEADER_VERSION_OFFSET = 1; + internal const int TLS_HEADER_LENGTH_OFFSET = 3; + private TlsProtocol mHandler; private Stream mInput; private Stream mOutput; @@ -116,11 +121,11 @@ namespace Org.BouncyCastle.Crypto.Tls internal virtual bool ReadRecord() { - byte[] recordHeader = TlsUtilities.ReadAllOrNothing(5, mInput); + byte[] recordHeader = TlsUtilities.ReadAllOrNothing(TLS_HEADER_SIZE, mInput); if (recordHeader == null) return false; - byte type = TlsUtilities.ReadUint8(recordHeader, 0); + byte type = TlsUtilities.ReadUint8(recordHeader, TLS_HEADER_TYPE_OFFSET); /* * RFC 5246 6. If a TLS implementation receives an unexpected record type, it MUST send an @@ -130,13 +135,13 @@ namespace Org.BouncyCastle.Crypto.Tls if (!mRestrictReadVersion) { - int version = TlsUtilities.ReadVersionRaw(recordHeader, 1); + int version = TlsUtilities.ReadVersionRaw(recordHeader, TLS_HEADER_VERSION_OFFSET); if ((version & 0xffffff00) != 0x0300) throw new TlsFatalAlert(AlertDescription.illegal_parameter); } else { - ProtocolVersion version = TlsUtilities.ReadVersion(recordHeader, 1); + ProtocolVersion version = TlsUtilities.ReadVersion(recordHeader, TLS_HEADER_VERSION_OFFSET); if (mReadVersion == null) { mReadVersion = version; @@ -147,7 +152,7 @@ namespace Org.BouncyCastle.Crypto.Tls } } - int length = TlsUtilities.ReadUint16(recordHeader, 3); + int length = TlsUtilities.ReadUint16(recordHeader, TLS_HEADER_LENGTH_OFFSET); byte[] plaintext = DecodeAndVerify(type, mInput, length); mHandler.ProcessRecord(type, plaintext, 0, plaintext.Length); return true; @@ -247,11 +252,11 @@ namespace Org.BouncyCastle.Crypto.Tls */ CheckLength(ciphertext.Length, mCiphertextLimit, AlertDescription.internal_error); - byte[] record = new byte[ciphertext.Length + 5]; - TlsUtilities.WriteUint8(type, record, 0); - TlsUtilities.WriteVersion(mWriteVersion, record, 1); - TlsUtilities.WriteUint16(ciphertext.Length, record, 3); - Array.Copy(ciphertext, 0, record, 5, ciphertext.Length); + byte[] record = new byte[ciphertext.Length + TLS_HEADER_SIZE]; + TlsUtilities.WriteUint8(type, record, TLS_HEADER_TYPE_OFFSET); + TlsUtilities.WriteVersion(mWriteVersion, record, TLS_HEADER_VERSION_OFFSET); + TlsUtilities.WriteUint16(ciphertext.Length, record, TLS_HEADER_LENGTH_OFFSET); + Array.Copy(ciphertext, 0, record, TLS_HEADER_SIZE, ciphertext.Length); mOutput.Write(record, 0, record.Length); mOutput.Flush(); } diff --git a/crypto/src/crypto/tls/TlsClientProtocol.cs b/crypto/src/crypto/tls/TlsClientProtocol.cs index 7b8439acc..14c1cf4a4 100644 --- a/crypto/src/crypto/tls/TlsClientProtocol.cs +++ b/crypto/src/crypto/tls/TlsClientProtocol.cs @@ -21,21 +21,56 @@ namespace Org.BouncyCastle.Crypto.Tls protected CertificateStatus mCertificateStatus = null; protected CertificateRequest mCertificateRequest = null; + /** + * Constructor for blocking mode. + * @param stream The bi-directional stream of data to/from the server + * @param secureRandom Random number generator for various cryptographic functions + */ public TlsClientProtocol(Stream stream, SecureRandom secureRandom) - : base(stream, secureRandom) + : base(stream, secureRandom) { } + /** + * Constructor for blocking mode. + * @param input The stream of data from the server + * @param output The stream of data to the server + * @param secureRandom Random number generator for various cryptographic functions + */ public TlsClientProtocol(Stream input, Stream output, SecureRandom secureRandom) - : base(input, output, secureRandom) + : base(input, output, secureRandom) + { + } + + /** + * Constructor for non-blocking mode.<br> + * <br> + * When data is received, use {@link #offerInput(java.nio.ByteBuffer)} to + * provide the received ciphertext, then use + * {@link #readInput(byte[], int, int)} to read the corresponding cleartext.<br> + * <br> + * Similarly, when data needs to be sent, use + * {@link #offerOutput(byte[], int, int)} to provide the cleartext, then use + * {@link #readOutput(byte[], int, int)} to get the corresponding + * ciphertext. + * + * @param secureRandom + * Random number generator for various cryptographic functions + */ + public TlsClientProtocol(SecureRandom secureRandom) + : base(secureRandom) { } /** - * Initiates a TLS handshake in the role of client + * Initiates a TLS handshake in the role of client.<br> + * <br> + * In blocking mode, this will not return until the handshake is complete. + * In non-blocking mode, use {@link TlsPeer#NotifyHandshakeComplete()} to + * receive a callback when the handshake is complete. * * @param tlsClient The {@link TlsClient} to use for the handshake. - * @throws IOException If handshake was not successful. + * @throws IOException If in blocking mode and handshake was not successful. */ public virtual void Connect(TlsClient tlsClient) { @@ -71,7 +106,7 @@ namespace Org.BouncyCastle.Crypto.Tls SendClientHelloMessage(); this.mConnectionState = CS_CLIENT_HELLO; - CompleteHandshake(); + BlockForHandshake(); } protected override void CleanupHandshake() @@ -116,6 +151,7 @@ namespace Org.BouncyCastle.Crypto.Tls this.mConnectionState = CS_CLIENT_FINISHED; this.mConnectionState = CS_END; + CompleteHandshake(); return; } @@ -208,6 +244,8 @@ namespace Org.BouncyCastle.Crypto.Tls ProcessFinishedMessage(buf); this.mConnectionState = CS_SERVER_FINISHED; this.mConnectionState = CS_END; + + CompleteHandshake(); break; } default: diff --git a/crypto/src/crypto/tls/TlsProtocol.cs b/crypto/src/crypto/tls/TlsProtocol.cs index 8eb7beb3f..7acc34d3c 100644 --- a/crypto/src/crypto/tls/TlsProtocol.cs +++ b/crypto/src/crypto/tls/TlsProtocol.cs @@ -72,6 +72,10 @@ namespace Org.BouncyCastle.Crypto.Tls protected bool mAllowCertificateStatus = false; protected bool mExpectSessionTicket = false; + protected bool mBlocking = true; + protected ByteQueueStream mInputBuffers = null; + protected ByteQueueStream mOutputBuffer = null; + public TlsProtocol(Stream stream, SecureRandom secureRandom) : this(stream, stream, secureRandom) { @@ -83,6 +87,15 @@ namespace Org.BouncyCastle.Crypto.Tls this.mSecureRandom = secureRandom; } + public TlsProtocol(SecureRandom secureRandom) + { + this.mBlocking = false; + this.mInputBuffers = new ByteQueueStream(); + this.mOutputBuffer = new ByteQueueStream(); + this.mRecordStream = new RecordStream(this, mInputBuffers, mOutputBuffer); + this.mSecureRandom = secureRandom; + } + protected abstract TlsContext Context { get; } internal abstract AbstractTlsContext ContextAdmin { get; } @@ -140,13 +153,10 @@ namespace Org.BouncyCastle.Crypto.Tls this.mExpectSessionTicket = false; } - protected virtual void CompleteHandshake() + protected virtual void BlockForHandshake() { - try + if (mBlocking) { - /* - * We will now read data, until we have completed the handshake. - */ while (this.mConnectionState != CS_END) { if (this.mClosed) @@ -156,7 +166,13 @@ namespace Org.BouncyCastle.Crypto.Tls SafeReadRecord(); } + } + } + protected virtual void CompleteHandshake() + { + try + { this.mRecordStream.FinaliseHandshake(); this.mSplitApplicationDataRecords = !TlsUtilities.IsTlsV11(Context); @@ -168,7 +184,10 @@ namespace Org.BouncyCastle.Crypto.Tls { this.mAppDataReady = true; - this.mTlsStream = new TlsStream(this); + if (mBlocking) + { + this.mTlsStream = new TlsStream(this); + } } if (this.mTlsSession != null) @@ -573,9 +592,156 @@ namespace Org.BouncyCastle.Crypto.Tls } /// <summary>The secure bidirectional stream for this connection</summary> + /// <remarks>Only allowed in blocking mode.</remarks> public virtual Stream Stream { - get { return this.mTlsStream; } + get + { + if (!mBlocking) + throw new InvalidOperationException("Cannot use Stream in non-blocking mode! Use OfferInput()/OfferOutput() instead."); + return this.mTlsStream; + } + } + + /** + * Offer input from an arbitrary source. Only allowed in non-blocking mode.<br> + * <br> + * After this method returns, the input buffer is "owned" by this object. Other code + * must not attempt to do anything with it.<br> + * <br> + * This method will decrypt and process all records that are fully available. + * If only part of a record is available, the buffer will be retained until the + * remainder of the record is offered.<br> + * <br> + * If any records containing application data were processed, the decrypted data + * can be obtained using {@link #readInput(byte[], int, int)}. If any records + * containing protocol data were processed, a response may have been generated. + * You should always check to see if there is any available output after calling + * this method by calling {@link #getAvailableOutputBytes()}. + * @param input The input buffer to offer + * @throws IOException If an error occurs while decrypting or processing a record + */ + public virtual void OfferInput(byte[] input) + { + if (mBlocking) + throw new InvalidOperationException("Cannot use OfferInput() in blocking mode! Use Stream instead."); + if (mClosed) + throw new IOException("Connection is closed, cannot accept any more input"); + + mInputBuffers.Write(input); + + // loop while there are enough bytes to read the length of the next record + while (mInputBuffers.Available >= RecordStream.TLS_HEADER_SIZE) + { + byte[] header = new byte[RecordStream.TLS_HEADER_SIZE]; + mInputBuffers.Peek(header); + + int totalLength = TlsUtilities.ReadUint16(header, RecordStream.TLS_HEADER_LENGTH_OFFSET) + RecordStream.TLS_HEADER_SIZE; + if (mInputBuffers.Available < totalLength) + { + // not enough bytes to read a whole record + break; + } + + SafeReadRecord(); + } + } + + /** + * Gets the amount of received application data. A call to {@link #readInput(byte[], int, int)} + * is guaranteed to be able to return at least this much data.<br> + * <br> + * Only allowed in non-blocking mode. + * @return The number of bytes of available application data + */ + public virtual int GetAvailableInputBytes() + { + if (mBlocking) + throw new InvalidOperationException("Cannot use GetAvailableInputBytes() in blocking mode! Use ApplicationDataAvailable() instead."); + + return ApplicationDataAvailable(); + } + + /** + * Retrieves received application data. Use {@link #getAvailableInputBytes()} to check + * how much application data is currently available. This method functions similarly to + * {@link InputStream#read(byte[], int, int)}, except that it never blocks. If no data + * is available, nothing will be copied and zero will be returned.<br> + * <br> + * Only allowed in non-blocking mode. + * @param buffer The buffer to hold the application data + * @param offset The start offset in the buffer at which the data is written + * @param length The maximum number of bytes to read + * @return The total number of bytes copied to the buffer. May be less than the + * length specified if the length was greater than the amount of available data. + */ + public virtual int ReadInput(byte[] buffer, int offset, int length) + { + if (mBlocking) + throw new InvalidOperationException("Cannot use ReadInput() in blocking mode! Use Stream instead."); + + return ReadApplicationData(buffer, offset, System.Math.Min(length, ApplicationDataAvailable())); + } + + /** + * Offer output from an arbitrary source. Only allowed in non-blocking mode.<br> + * <br> + * After this method returns, the specified section of the buffer will have been + * processed. Use {@link #readOutput(byte[], int, int)} to get the bytes to + * transmit to the other peer.<br> + * <br> + * This method must not be called until after the handshake is complete! Attempting + * to call it before the handshake is complete will result in an exception. + * @param buffer The buffer containing application data to encrypt + * @param offset The offset at which to begin reading data + * @param length The number of bytes of data to read + * @throws IOException If an error occurs encrypting the data, or the handshake is not complete + */ + public virtual void OfferOutput(byte[] buffer, int offset, int length) + { + if (mBlocking) + throw new InvalidOperationException("Cannot use OfferOutput() in blocking mode! Use Stream instead."); + if (!mAppDataReady) + throw new IOException("Application data cannot be sent until the handshake is complete!"); + + WriteData(buffer, offset, length); + } + + /** + * Gets the amount of encrypted data available to be sent. A call to + * {@link #readOutput(byte[], int, int)} is guaranteed to be able to return at + * least this much data.<br> + * <br> + * Only allowed in non-blocking mode. + * @return The number of bytes of available encrypted data + */ + public virtual int GetAvailableOutputBytes() + { + if (mBlocking) + throw new InvalidOperationException("Cannot use GetAvailableOutputBytes() in blocking mode! Use Stream instead."); + + return mOutputBuffer.Available; + } + + /** + * Retrieves encrypted data to be sent. Use {@link #getAvailableOutputBytes()} to check + * how much encrypted data is currently available. This method functions similarly to + * {@link InputStream#read(byte[], int, int)}, except that it never blocks. If no data + * is available, nothing will be copied and zero will be returned.<br> + * <br> + * Only allowed in non-blocking mode. + * @param buffer The buffer to hold the encrypted data + * @param offset The start offset in the buffer at which the data is written + * @param length The maximum number of bytes to read + * @return The total number of bytes copied to the buffer. May be less than the + * length specified if the length was greater than the amount of available data. + */ + public virtual int ReadOutput(byte[] buffer, int offset, int length) + { + if (mBlocking) + throw new InvalidOperationException("Cannot use ReadOutput() in blocking mode! Use Stream instead."); + + return mOutputBuffer.Read(buffer, offset, length); } /** @@ -764,7 +930,7 @@ namespace Org.BouncyCastle.Crypto.Tls mRecordStream.Flush(); } - protected internal virtual bool IsClosed + public virtual bool IsClosed { get { return mClosed; } } diff --git a/crypto/src/crypto/tls/TlsServerProtocol.cs b/crypto/src/crypto/tls/TlsServerProtocol.cs index b73cb5a30..27f7a1dfd 100644 --- a/crypto/src/crypto/tls/TlsServerProtocol.cs +++ b/crypto/src/crypto/tls/TlsServerProtocol.cs @@ -22,21 +22,57 @@ namespace Org.BouncyCastle.Crypto.Tls protected short mClientCertificateType = -1; protected TlsHandshakeHash mPrepareFinishHash = null; + /** + * Constructor for blocking mode. + * @param stream The bi-directional stream of data to/from the client + * @param output The stream of data to the client + * @param secureRandom Random number generator for various cryptographic functions + */ public TlsServerProtocol(Stream stream, SecureRandom secureRandom) - : base(stream, secureRandom) + : base(stream, secureRandom) { } + /** + * Constructor for blocking mode. + * @param input The stream of data from the client + * @param output The stream of data to the client + * @param secureRandom Random number generator for various cryptographic functions + */ public TlsServerProtocol(Stream input, Stream output, SecureRandom secureRandom) - : base(input, output, secureRandom) + : base(input, output, secureRandom) + { + } + + /** + * Constructor for non-blocking mode.<br> + * <br> + * When data is received, use {@link #offerInput(java.nio.ByteBuffer)} to + * provide the received ciphertext, then use + * {@link #readInput(byte[], int, int)} to read the corresponding cleartext.<br> + * <br> + * Similarly, when data needs to be sent, use + * {@link #offerOutput(byte[], int, int)} to provide the cleartext, then use + * {@link #readOutput(byte[], int, int)} to get the corresponding + * ciphertext. + * + * @param secureRandom + * Random number generator for various cryptographic functions + */ + public TlsServerProtocol(SecureRandom secureRandom) + : base(secureRandom) { } /** - * Receives a TLS handshake in the role of server + * Receives a TLS handshake in the role of server.<br> + * <br> + * In blocking mode, this will not return until the handshake is complete. + * In non-blocking mode, use {@link TlsPeer#notifyHandshakeComplete()} to + * receive a callback when the handshake is complete. * - * @param mTlsServer - * @throws IOException If handshake was not successful. + * @param tlsServer + * @throws IOException If in blocking mode and handshake was not successful. */ public virtual void Accept(TlsServer tlsServer) { @@ -60,7 +96,7 @@ namespace Org.BouncyCastle.Crypto.Tls this.mRecordStream.SetRestrictReadVersion(false); - CompleteHandshake(); + BlockForHandshake(); } protected override void CleanupHandshake() @@ -329,6 +365,8 @@ namespace Org.BouncyCastle.Crypto.Tls SendFinishedMessage(); this.mConnectionState = CS_SERVER_FINISHED; this.mConnectionState = CS_END; + + CompleteHandshake(); break; } default: diff --git a/crypto/test/src/crypto/tls/test/ByteQueueStreamTest.cs b/crypto/test/src/crypto/tls/test/ByteQueueStreamTest.cs new file mode 100644 index 000000000..1d68a5215 --- /dev/null +++ b/crypto/test/src/crypto/tls/test/ByteQueueStreamTest.cs @@ -0,0 +1,134 @@ +using System; + +using NUnit.Framework; + +using Org.BouncyCastle.Utilities; + +namespace Org.BouncyCastle.Crypto.Tls.Tests +{ + [TestFixture] + public class ByteQueueStreamTest + { + [Test] + public void TestAvailable() + { + ByteQueueStream input = new ByteQueueStream(); + + // buffer is empty + Assert.AreEqual(0, input.Available); + + // after adding once + input.Write(new byte[10]); + Assert.AreEqual(10, input.Available); + + // after adding more than once + input.Write(new byte[5]); + Assert.AreEqual(15, input.Available); + + // after reading a single byte + input.ReadByte(); + Assert.AreEqual(14, input.Available); + + // after reading into a byte array + input.Read(new byte[4]); + Assert.AreEqual(10, input.Available); + + input.Close(); // so compiler doesn't whine about a resource leak + } + + [Test] + public void TestSkip() + { + ByteQueueStream input = new ByteQueueStream(); + + // skip when buffer is empty + Assert.AreEqual(0, input.Skip(10)); + + // skip equal to available + input.Write(new byte[2]); + Assert.AreEqual(2, input.Skip(2)); + Assert.AreEqual(0, input.Available); + + // skip less than available + input.Write(new byte[10]); + Assert.AreEqual(5, input.Skip(5)); + Assert.AreEqual(5, input.Available); + + // skip more than available + Assert.AreEqual(5, input.Skip(20)); + Assert.AreEqual(0, input.Available); + + input.Close();// so compiler doesn't whine about a resource leak + } + + [Test] + public void TestRead() + { + ByteQueueStream input = new ByteQueueStream(); + input.Write(new byte[] { 0x01, 0x02 }); + input.Write(new byte[]{ 0x03 }); + + Assert.AreEqual(0x01, input.ReadByte()); + Assert.AreEqual(0x02, input.ReadByte()); + Assert.AreEqual(0x03, input.ReadByte()); + Assert.AreEqual(-1, input.ReadByte()); + + input.Close(); // so compiler doesn't whine about a resource leak + } + + [Test] + public void TestReadArray() + { + ByteQueueStream input = new ByteQueueStream(); + input.Write(new byte[] { 0x01, 0x02, 0x03, 0x04, 0x05, 0x06 }); + + byte[] buffer = new byte[5]; + + // read less than available into specified position + Assert.AreEqual(1, input.Read(buffer, 2, 1)); + AssertArrayEquals(new byte[]{ 0x00, 0x00, 0x01, 0x00, 0x00 }, buffer); + + // read equal to available + Assert.AreEqual(5, input.Read(buffer)); + AssertArrayEquals(new byte[]{ 0x02, 0x03, 0x04, 0x05, 0x06 }, buffer); + + // read more than available + input.Write(new byte[]{ 0x01, 0x02, 0x03 }); + Assert.AreEqual(3, input.Read(buffer)); + AssertArrayEquals(new byte[]{ 0x01, 0x02, 0x03, 0x05, 0x06 }, buffer); + + input.Close(); // so compiler doesn't whine about a resource leak + } + + [Test] + public void TestPeek() + { + ByteQueueStream input = new ByteQueueStream(); + + byte[] buffer = new byte[5]; + + // peek more than available + Assert.AreEqual(0, input.Peek(buffer)); + AssertArrayEquals(new byte[]{ 0x00, 0x00, 0x00, 0x00, 0x00 }, buffer); + + // peek less than available + input.Write(new byte[]{ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06 }); + Assert.AreEqual(5, input.Peek(buffer)); + AssertArrayEquals(new byte[]{ 0x01, 0x02, 0x03, 0x04, 0x05 }, buffer); + Assert.AreEqual(6, input.Available); + + // peek equal to available + input.ReadByte(); + Assert.AreEqual(5, input.Peek(buffer)); + AssertArrayEquals(new byte[]{ 0x02, 0x03, 0x04, 0x05, 0x06 }, buffer); + Assert.AreEqual(5, input.Available); + + input.Close(); // so compiler doesn't whine about a resource leak + } + + private static void AssertArrayEquals(byte[] a, byte[] b) + { + Assert.IsTrue(Arrays.AreEqual(a, b)); + } + } +} diff --git a/crypto/test/src/crypto/tls/test/TlsProtocolNonBlockingTest.cs b/crypto/test/src/crypto/tls/test/TlsProtocolNonBlockingTest.cs new file mode 100644 index 000000000..5fe0f32ad --- /dev/null +++ b/crypto/test/src/crypto/tls/test/TlsProtocolNonBlockingTest.cs @@ -0,0 +1,126 @@ +using System; +using System.IO; + +using NUnit.Framework; + +using Org.BouncyCastle.Security; +using Org.BouncyCastle.Utilities; + +namespace Org.BouncyCastle.Crypto.Tls.Tests +{ + [TestFixture] + public class TlsProtocolNonBlockingTest + { + [Test] + public void TestClientServerFragmented() + { + // tests if it's really non-blocking when partial records arrive + DoTestClientServer(true); + } + + [Test] + public void TestClientServerNonFragmented() + { + DoTestClientServer(false); + } + + private static void DoTestClientServer(bool fragment) + { + SecureRandom secureRandom = new SecureRandom(); + + TlsClientProtocol clientProtocol = new TlsClientProtocol(secureRandom); + TlsServerProtocol serverProtocol = new TlsServerProtocol(secureRandom); + + clientProtocol.Connect(new MockTlsClient(null)); + serverProtocol.Accept(new MockTlsServer()); + + // pump handshake + bool hadDataFromServer = true; + bool hadDataFromClient = true; + while (hadDataFromServer || hadDataFromClient) + { + hadDataFromServer = PumpData(serverProtocol, clientProtocol, fragment); + hadDataFromClient = PumpData(clientProtocol, serverProtocol, fragment); + } + + // send data in both directions + byte[] data = new byte[1024]; + secureRandom.NextBytes(data); + WriteAndRead(clientProtocol, serverProtocol, data, fragment); + WriteAndRead(serverProtocol, clientProtocol, data, fragment); + + // close the connection + clientProtocol.Close(); + PumpData(clientProtocol, serverProtocol, fragment); + CheckClosed(serverProtocol); + CheckClosed(clientProtocol); + } + + private static void WriteAndRead(TlsProtocol writer, TlsProtocol reader, byte[] data, bool fragment) + { + int dataSize = data.Length; + writer.OfferOutput(data, 0, dataSize); + PumpData(writer, reader, fragment); + + Assert.AreEqual(dataSize, reader.GetAvailableInputBytes()); + byte[] readData = new byte[dataSize]; + reader.ReadInput(readData, 0, dataSize); + AssertArrayEquals(data, readData); + } + + private static bool PumpData(TlsProtocol from, TlsProtocol to, bool fragment) + { + int byteCount = from.GetAvailableOutputBytes(); + if (byteCount == 0) + { + return false; + } + + if (fragment) + { + while (from.GetAvailableOutputBytes() > 0) + { + byte[] buffer = new byte[1]; + from.ReadOutput(buffer, 0, 1); + to.OfferInput(buffer); + } + } + else + { + byte[] buffer = new byte[byteCount]; + from.ReadOutput(buffer, 0, buffer.Length); + to.OfferInput(buffer); + } + + return true; + } + + private static void CheckClosed(TlsProtocol protocol) + { + Assert.IsTrue(protocol.IsClosed); + + try + { + protocol.OfferInput(new byte[10]); + Assert.Fail("Input was accepted after close"); + } + catch (IOException e) + { + } + + try + { + protocol.OfferOutput(new byte[10], 0, 10); + Assert.Fail("Output was accepted after close"); + } + catch (IOException e) + { + } + } + + private static void AssertArrayEquals(byte[] a, byte[] b) + { + Assert.IsTrue(Arrays.AreEqual(a, b)); + } + } +} |