From 95db89f0bcf07e49ed86b235f3953718a50b6f54 Mon Sep 17 00:00:00 2001 From: Peter Dettman Date: Thu, 9 Mar 2023 16:17:26 +0700 Subject: Refactoring around Stream usage --- crypto/src/util/io/Streams.cs | 169 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 159 insertions(+), 10 deletions(-) (limited to 'crypto/src/util/io/Streams.cs') diff --git a/crypto/src/util/io/Streams.cs b/crypto/src/util/io/Streams.cs index da8f01068..b975d03bd 100644 --- a/crypto/src/util/io/Streams.cs +++ b/crypto/src/util/io/Streams.cs @@ -1,15 +1,78 @@ using System; using System.IO; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; namespace Org.BouncyCastle.Utilities.IO { - public static class Streams + public static class Streams { - private const int BufferSize = 4096; + private static readonly int MaxStackAlloc = Environment.Is64BitProcess ? 4096 : 1024; - public static void Drain(Stream inStr) + public static int DefaultBufferSize => MaxStackAlloc; + + public static void CopyTo(Stream source, Stream destination) + { + CopyTo(source, destination, DefaultBufferSize); + } + + public static void CopyTo(Stream source, Stream destination, int bufferSize) + { + int bytesRead; +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + Span buffer = bufferSize <= MaxStackAlloc + ? stackalloc byte[bufferSize] + : new byte[bufferSize]; + while ((bytesRead = source.Read(buffer)) != 0) + { + destination.Write(buffer[..bytesRead]); + } +#else + byte[] buffer = new byte[bufferSize]; + while ((bytesRead = source.Read(buffer, 0, buffer.Length)) != 0) + { + destination.Write(buffer, 0, bytesRead); + } +#endif + } + + public static Task CopyToAsync(Stream source, Stream destination) + { + return CopyToAsync(source, destination, DefaultBufferSize); + } + + public static Task CopyToAsync(Stream source, Stream destination, int bufferSize) + { + return CopyToAsync(source, destination, bufferSize, CancellationToken.None); + } + + public static Task CopyToAsync(Stream source, Stream destination, CancellationToken cancellationToken) + { + return CopyToAsync(source, destination, DefaultBufferSize, cancellationToken); + } + + public static async Task CopyToAsync(Stream source, Stream destination, int bufferSize, + CancellationToken cancellationToken) + { + int bytesRead; + byte[] buffer = new byte[bufferSize]; +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + while ((bytesRead = await ReadAsync(source, new Memory(buffer), cancellationToken).ConfigureAwait(false)) != 0) + { + await WriteAsync(destination, new ReadOnlyMemory(buffer, 0, bytesRead), cancellationToken).ConfigureAwait(false); + } +#else + while ((bytesRead = await ReadAsync(source, buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false)) != 0) + { + await WriteAsync(destination, buffer, 0, bytesRead, cancellationToken).ConfigureAwait(false); + } +#endif + } + + public static void Drain(Stream inStr) { - inStr.CopyTo(Stream.Null, BufferSize); + CopyTo(inStr, Stream.Null, DefaultBufferSize); } /// Write the full contents of inStr to the destination stream outStr. @@ -18,7 +81,7 @@ namespace Org.BouncyCastle.Utilities.IO /// In case of IO failure. public static void PipeAll(Stream inStr, Stream outStr) { - inStr.CopyTo(outStr, BufferSize); + PipeAll(inStr, outStr, DefaultBufferSize); } /// Write the full contents of inStr to the destination stream outStr. @@ -28,7 +91,7 @@ namespace Org.BouncyCastle.Utilities.IO /// In case of IO failure. public static void PipeAll(Stream inStr, Stream outStr, int bufferSize) { - inStr.CopyTo(outStr, bufferSize); + CopyTo(inStr, outStr, bufferSize); } /// @@ -48,12 +111,17 @@ namespace Org.BouncyCastle.Utilities.IO /// public static long PipeAllLimited(Stream inStr, long limit, Stream outStr) { - var limited = new LimitedInputStream(inStr, limit); - limited.CopyTo(outStr, BufferSize); - return limit - limited.CurrentLimit; + return PipeAllLimited(inStr, limit, outStr, DefaultBufferSize); } - public static byte[] ReadAll(Stream inStr) + public static long PipeAllLimited(Stream inStr, long limit, Stream outStr, int bufferSize) + { + var limited = new LimitedInputStream(inStr, limit); + CopyTo(limited, outStr, bufferSize); + return limit - limited.CurrentLimit; + } + + public static byte[] ReadAll(Stream inStr) { MemoryStream buf = new MemoryStream(); PipeAll(inStr, buf); @@ -72,6 +140,48 @@ namespace Org.BouncyCastle.Utilities.IO return buf.ToArray(); } + public static Task ReadAsync(Stream source, byte[] buffer, int offset, int count) + { + return source.ReadAsync(buffer, offset, count); + } + + public static Task ReadAsync(Stream source, byte[] buffer, int offset, int count, + CancellationToken cancellationToken) + { + return source.ReadAsync(buffer, offset, count, cancellationToken); + } + +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + public static ValueTask ReadAsync(Stream source, Memory buffer, + CancellationToken cancellationToken = default) + { + if (MemoryMarshal.TryGetArray(buffer, out ArraySegment array)) + { + return new ValueTask( + ReadAsync(source, array.Array!, array.Offset, array.Count, cancellationToken)); + } + + byte[] sharedBuffer = new byte[buffer.Length]; + var readTask = ReadAsync(source, sharedBuffer, 0, buffer.Length, cancellationToken); + return FinishReadAsync(readTask, sharedBuffer, buffer); + } + + private static async ValueTask FinishReadAsync(Task readTask, byte[] localBuffer, + Memory localDestination) + { + try + { + int result = await readTask.ConfigureAwait(false); + new ReadOnlySpan(localBuffer, 0, result).CopyTo(localDestination.Span); + return result; + } + finally + { + Array.Fill(localBuffer, 0x00); + } + } +#endif + public static int ReadFully(Stream inStr, byte[] buf) { return ReadFully(inStr, buf, 0, buf.Length); @@ -117,6 +227,45 @@ namespace Org.BouncyCastle.Utilities.IO throw new ArgumentOutOfRangeException("count"); } + public static Task WriteAsync(Stream destination, byte[] buffer, int offset, int count) + { + return destination.WriteAsync(buffer, offset, count); + } + + public static Task WriteAsync(Stream destination, byte[] buffer, int offset, int count, + CancellationToken cancellationToken) + { + return destination.WriteAsync(buffer, offset, count, cancellationToken); + } + +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + public static ValueTask WriteAsync(Stream destination, ReadOnlyMemory buffer, + CancellationToken cancellationToken = default) + { + if (MemoryMarshal.TryGetArray(buffer, out ArraySegment array)) + { + return new ValueTask( + WriteAsync(destination, array.Array!, array.Offset, array.Count, cancellationToken)); + } + + byte[] sharedBuffer = buffer.ToArray(); + var writeTask = WriteAsync(destination, sharedBuffer, 0, buffer.Length, cancellationToken); + return new ValueTask(FinishWriteAsync(writeTask, sharedBuffer)); + } + + private static async Task FinishWriteAsync(Task writeTask, byte[] localBuffer) + { + try + { + await writeTask.ConfigureAwait(false); + } + finally + { + Array.Fill(localBuffer, 0x00); + } + } +#endif + /// public static int WriteBufTo(MemoryStream buf, byte[] output, int offset) { -- cgit 1.4.1