about summary refs log tree commit diff
path: root/LibMatrix/Extensions
diff options
context:
space:
mode:
Diffstat (limited to 'LibMatrix/Extensions')
-rw-r--r--LibMatrix/Extensions/CanonicalJsonSerializer.cs96
-rw-r--r--LibMatrix/Extensions/EnumerableExtensions.cs85
-rw-r--r--LibMatrix/Extensions/JsonElementExtensions.cs1
-rw-r--r--LibMatrix/Extensions/MatrixHttpClient.Multi.cs11
-rw-r--r--LibMatrix/Extensions/MatrixHttpClient.Single.cs106
-rw-r--r--LibMatrix/Extensions/UnicodeJsonEncoder.cs173
6 files changed, 429 insertions, 43 deletions
diff --git a/LibMatrix/Extensions/CanonicalJsonSerializer.cs b/LibMatrix/Extensions/CanonicalJsonSerializer.cs
new file mode 100644

index 0000000..ae535aa --- /dev/null +++ b/LibMatrix/Extensions/CanonicalJsonSerializer.cs
@@ -0,0 +1,96 @@ +using System.Collections.Frozen; +using System.Reflection; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization.Metadata; +using ArcaneLibs.Extensions; + +namespace LibMatrix.Extensions; + +public static class CanonicalJsonSerializer { + // TODO: Alphabetise dictionaries + private static JsonSerializerOptions JsonOptions => new() { + WriteIndented = false, + Encoder = UnicodeJsonEncoder.Singleton, + }; + + private static readonly FrozenSet<PropertyInfo> JsonSerializerOptionsProperties = typeof(JsonSerializerOptions) + .GetProperties(BindingFlags.Public | BindingFlags.Instance) + .Where(x => x.SetMethod != null && x.GetMethod != null) + .ToFrozenSet(); + + private static JsonSerializerOptions MergeOptions(JsonSerializerOptions? inputOptions) { + var newOptions = JsonOptions; + if (inputOptions == null) + return newOptions; + + foreach (var property in JsonSerializerOptionsProperties) { + if(property.Name == nameof(JsonSerializerOptions.Encoder)) + continue; + if (property.Name == nameof(JsonSerializerOptions.WriteIndented)) + continue; + + var value = property.GetValue(inputOptions); + // if (value == null) + // continue; + property.SetValue(newOptions, value); + } + + return newOptions; + } + +#region STJ API + + public static String Serialize<TValue>(TValue value, JsonSerializerOptions? options = null) { + var newOptions = MergeOptions(options); + + return JsonSerializer.SerializeToNode(value, options) // We want to allow passing custom converters for eg. double/float -> string here... + .SortProperties()! + .CanonicalizeNumbers()! + .ToJsonString(newOptions); + + + // System.Text.Json.JsonSerializer.SerializeToNode(System.Text.Json.JsonSerializer.Deserialize<dynamic>("{\n \"a\": -0,\n \"b\": 1e10\n}")).ToJsonString(); + + } + + public static String Serialize(object value, Type inputType, JsonSerializerOptions? options = null) => JsonSerializer.Serialize(value, inputType, JsonOptions); + // public static String Serialize<TValue>(TValue value, JsonTypeInfo<TValue> jsonTypeInfo) => JsonSerializer.Serialize(value, jsonTypeInfo, _options); + // public static String Serialize(Object value, JsonTypeInfo jsonTypeInfo) + + public static byte[] SerializeToUtf8Bytes<T>(T value, JsonSerializerOptions? options = null) { + var newOptions = MergeOptions(null); + return JsonSerializer.SerializeToNode(value, options) // We want to allow passing custom converters for eg. double/float -> string here... + .SortProperties()! + .CanonicalizeNumbers()! + .ToJsonString(newOptions).AsBytes().ToArray(); + } + +#endregion + + // ReSharper disable once UnusedType.Local + private static class JsonExtensions { + public static Action<JsonTypeInfo> AlphabetizeProperties(Type type) { + return typeInfo => { + if (typeInfo.Kind != JsonTypeInfoKind.Object || !type.IsAssignableFrom(typeInfo.Type)) + return; + AlphabetizeProperties()(typeInfo); + }; + } + + public static Action<JsonTypeInfo> AlphabetizeProperties() { + return static typeInfo => { + if (typeInfo.Kind == JsonTypeInfoKind.Dictionary) { } + + if (typeInfo.Kind != JsonTypeInfoKind.Object) + return; + var properties = typeInfo.Properties.OrderBy(p => p.Name, StringComparer.Ordinal).ToList(); + typeInfo.Properties.Clear(); + for (int i = 0; i < properties.Count; i++) { + properties[i].Order = i; + typeInfo.Properties.Add(properties[i]); + } + }; + } + } +} \ No newline at end of file diff --git a/LibMatrix/Extensions/EnumerableExtensions.cs b/LibMatrix/Extensions/EnumerableExtensions.cs
index 42d9491..4dcf26e 100644 --- a/LibMatrix/Extensions/EnumerableExtensions.cs +++ b/LibMatrix/Extensions/EnumerableExtensions.cs
@@ -1,29 +1,86 @@ +using System.Collections.Frozen; +using System.Collections.Immutable; + namespace LibMatrix.Extensions; public static class EnumerableExtensions { public static void MergeStateEventLists(this IList<StateEvent> oldState, IList<StateEvent> newState) { - foreach (var stateEvent in newState) { - var old = oldState.FirstOrDefault(x => x.Type == stateEvent.Type && x.StateKey == stateEvent.StateKey); - if (old is null) { - oldState.Add(stateEvent); - continue; + // foreach (var stateEvent in newState) { + // var old = oldState.FirstOrDefault(x => x.Type == stateEvent.Type && x.StateKey == stateEvent.StateKey); + // if (old is null) { + // oldState.Add(stateEvent); + // continue; + // } + // + // oldState.Remove(old); + // oldState.Add(stateEvent); + // } + + foreach (var e in newState) { + switch (FindIndex(e)) { + case -1: + oldState.Add(e); + break; + case var index: + oldState[index] = e; + break; + } + } + + int FindIndex(StateEvent needle) { + for (int i = 0; i < oldState.Count; i++) { + var old = oldState[i]; + if (old.Type == needle.Type && old.StateKey == needle.StateKey) + return i; } - oldState.Remove(old); - oldState.Add(stateEvent); + return -1; } } public static void MergeStateEventLists(this IList<StateEventResponse> oldState, IList<StateEventResponse> newState) { - foreach (var stateEvent in newState) { - var old = oldState.FirstOrDefault(x => x.Type == stateEvent.Type && x.StateKey == stateEvent.StateKey); - if (old is null) { - oldState.Add(stateEvent); - continue; + foreach (var e in newState) { + switch (FindIndex(e)) { + case -1: + oldState.Add(e); + break; + case var index: + oldState[index] = e; + break; + } + } + + int FindIndex(StateEventResponse needle) { + for (int i = 0; i < oldState.Count; i++) { + var old = oldState[i]; + if (old.Type == needle.Type && old.StateKey == needle.StateKey) + return i; + } + + return -1; + } + } + + public static void MergeStateEventLists(this List<StateEventResponse> oldState, List<StateEventResponse> newState) { + foreach (var e in newState) { + switch (FindIndex(e)) { + case -1: + oldState.Add(e); + break; + case var index: + oldState[index] = e; + break; + } + } + + int FindIndex(StateEventResponse needle) { + for (int i = 0; i < oldState.Count; i++) { + var old = oldState[i]; + if (old.Type == needle.Type && old.StateKey == needle.StateKey) + return i; } - oldState.Remove(old); - oldState.Add(stateEvent); + return -1; } } } \ No newline at end of file diff --git a/LibMatrix/Extensions/JsonElementExtensions.cs b/LibMatrix/Extensions/JsonElementExtensions.cs
index c4ed743..dfec95b 100644 --- a/LibMatrix/Extensions/JsonElementExtensions.cs +++ b/LibMatrix/Extensions/JsonElementExtensions.cs
@@ -126,6 +126,7 @@ public static class JsonElementExtensions { $"Encountered dictionary {field.Name} with key type {keyType.Name} and value type {valueType.Name}!"); return field.Value.EnumerateObject() + // TODO: use key.Value? .Where(key => !valueType.IsPrimitive && valueType != typeof(string)) .Aggregate(false, (current, key) => current | key.FindExtraJsonPropertyFieldsByValueKind(containerType, valueType) diff --git a/LibMatrix/Extensions/MatrixHttpClient.Multi.cs b/LibMatrix/Extensions/MatrixHttpClient.Multi.cs
index e7a2044..bf6fe63 100644 --- a/LibMatrix/Extensions/MatrixHttpClient.Multi.cs +++ b/LibMatrix/Extensions/MatrixHttpClient.Multi.cs
@@ -1,16 +1,5 @@ #define SINGLE_HTTPCLIENT // Use a single HttpClient instance for all MatrixHttpClient instances // #define SYNC_HTTPCLIENT // Only allow one request as a time, for debugging -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Net.Http.Headers; -using System.Reflection; -using System.Security.Cryptography.X509Certificates; -using System.Text; -using System.Text.Json; -using System.Text.Json.Serialization; -using ArcaneLibs; -using ArcaneLibs.Extensions; - namespace LibMatrix.Extensions; public static class HttpClientExtensions { diff --git a/LibMatrix/Extensions/MatrixHttpClient.Single.cs b/LibMatrix/Extensions/MatrixHttpClient.Single.cs
index 39eb7e5..671566f 100644 --- a/LibMatrix/Extensions/MatrixHttpClient.Single.cs +++ b/LibMatrix/Extensions/MatrixHttpClient.Single.cs
@@ -1,20 +1,21 @@ #define SINGLE_HTTPCLIENT // Use a single HttpClient instance for all MatrixHttpClient instances -// #define SYNC_HTTPCLIENT // Only allow one request as a time, for debugging +// #define SYNC_HTTPCLIENT // Only allow one request as a time, for debugging using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Net; using System.Net.Http.Headers; using System.Reflection; -using System.Security.Cryptography.X509Certificates; using System.Text; using System.Text.Json; using System.Text.Json.Serialization; using ArcaneLibs; using ArcaneLibs.Extensions; +using LibMatrix.Homeservers.ImplementationDetails.Synapse.Models.Requests; namespace LibMatrix.Extensions; #if SINGLE_HTTPCLIENT -// TODO: Add URI wrapper for +// TODO: Add URI wrapper for public class MatrixHttpClient { private static readonly HttpClient Client; @@ -50,6 +51,7 @@ public class MatrixHttpClient { internal SemaphoreSlim _rateLimitSemaphore { get; } = new(1, 1); #endif + public static bool LogRequests = true; public Dictionary<string, string> AdditionalQueryParameters { get; set; } = new(); public Uri? BaseAddress { get; set; } @@ -59,7 +61,7 @@ public class MatrixHttpClient { typeof(HttpRequestHeaders).GetConstructor(BindingFlags.NonPublic | BindingFlags.Instance, null, [], null)?.Invoke([]) as HttpRequestHeaders ?? throw new InvalidOperationException("Failed to create HttpRequestHeaders"); - private JsonSerializerOptions GetJsonSerializerOptions(JsonSerializerOptions? options = null) { + private static JsonSerializerOptions GetJsonSerializerOptions(JsonSerializerOptions? options = null) { options ??= new JsonSerializerOptions(); options.Converters.Add(new JsonFloatStringConverter()); options.Converters.Add(new JsonDoubleStringConverter()); @@ -68,27 +70,54 @@ public class MatrixHttpClient { return options; } - public async Task<HttpResponseMessage> SendUnhandledAsync(HttpRequestMessage request, CancellationToken cancellationToken) { + public async Task<HttpResponseMessage> SendUnhandledAsync(HttpRequestMessage request, CancellationToken cancellationToken, int attempt = 0) { + if (request.RequestUri is null) throw new NullReferenceException("RequestUri is null"); + // if (!request.RequestUri.IsAbsoluteUri) + request.RequestUri = request.RequestUri.EnsureAbsolute(BaseAddress!); + var swWait = Stopwatch.StartNew(); #if SYNC_HTTPCLIENT await _rateLimitSemaphore.WaitAsync(cancellationToken); #endif - Console.WriteLine($"Sending {request.Method} {BaseAddress}{request.RequestUri} ({Util.BytesToString(request.Content?.Headers.ContentLength ?? 0)})"); - if (request.RequestUri is null) throw new NullReferenceException("RequestUri is null"); - if (!request.RequestUri.IsAbsoluteUri) request.RequestUri = new Uri(BaseAddress, request.RequestUri); + if (!request.RequestUri.IsAbsoluteUri) + request.RequestUri = new Uri(BaseAddress ?? throw new InvalidOperationException("Relative URI passed, but no BaseAddress is specified!"), request.RequestUri); + swWait.Stop(); + var swExec = Stopwatch.StartNew(); + foreach (var (key, value) in AdditionalQueryParameters) request.RequestUri = request.RequestUri.AddQuery(key, value); - foreach (var (key, value) in DefaultRequestHeaders) request.Headers.Add(key, value); + foreach (var (key, value) in DefaultRequestHeaders) { + if (request.Headers.Contains(key)) continue; + request.Headers.Add(key, value); + } request.Options.Set(new HttpRequestOptionsKey<bool>("WebAssemblyEnableStreamingResponse"), true); + if (LogRequests) + Console.WriteLine("Sending " + request.Summarise(includeHeaders: true, includeQuery: true, includeContentIfText: false, hideHeaders: ["Accept"])); + HttpResponseMessage? responseMessage; try { responseMessage = await Client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); } catch (Exception e) { - Console.WriteLine( - $"Failed to send request {request.Method} {BaseAddress}{request.RequestUri} ({Util.BytesToString(request.Content?.Headers.ContentLength ?? 0)}):\n{e}"); + // if (attempt >= 5) { + // Console.WriteLine( + // $"Failed to send request {request.Method} {BaseAddress}{request.RequestUri} ({Util.BytesToString(request.Content?.Headers.ContentLength ?? 0)}):\n{e}"); + // throw; + // } + // + // if (e is TaskCanceledException or TimeoutException or HttpRequestException) { + // if (request.Method == HttpMethod.Get && !cancellationToken.IsCancellationRequested) { + // await Task.Delay(Random.Shared.Next(100, 1000), cancellationToken); + // request.ResetSendStatus(); + // return await SendUnhandledAsync(request, cancellationToken, attempt + 1); + // } + // } + // else if (!e.ToString().StartsWith("TypeError: NetworkError")) + // Console.WriteLine( + // $"Failed to send request {request.Method} {BaseAddress}{request.RequestUri} ({Util.BytesToString(request.Content?.Headers.ContentLength ?? 0)}):\n{e}"); + throw; } #if SYNC_HTTPCLIENT @@ -97,8 +126,26 @@ public class MatrixHttpClient { } #endif - Console.WriteLine( - $"Sending {request.Method} {request.RequestUri} ({Util.BytesToString(request.Content?.Headers.ContentLength ?? 0)}) -> {(int)responseMessage.StatusCode} {responseMessage.StatusCode} ({Util.BytesToString(responseMessage.Content.Headers.ContentLength ?? 0)})"); + // Console.WriteLine($"Sending {request.Method} {request.RequestUri} ({Util.BytesToString(request.Content?.Headers.ContentLength ?? 0)}) -> {(int)responseMessage.StatusCode} {responseMessage.StatusCode} ({Util.BytesToString(responseMessage.GetContentLength())}, WAIT={swWait.ElapsedMilliseconds}ms, EXEC={swExec.ElapsedMilliseconds}ms)"); + if (LogRequests) + Console.WriteLine("Received " + responseMessage.Summarise(includeHeaders: true, includeContentIfText: false, hideHeaders: [ + "Server", + "Date", + "Transfer-Encoding", + "Connection", + "Vary", + "Content-Length", + "Access-Control-Allow-Origin", + "Access-Control-Allow-Methods", + "Access-Control-Allow-Headers", + "Access-Control-Expose-Headers", + "Cache-Control", + "Cross-Origin-Resource-Policy", + "X-Content-Security-Policy", + "Referrer-Policy", + "X-Robots-Tag", + "Content-Security-Policy" + ])); return responseMessage; } @@ -107,6 +154,12 @@ public class MatrixHttpClient { var responseMessage = await SendUnhandledAsync(request, cancellationToken); if (responseMessage.IsSuccessStatusCode) return responseMessage; + //retry on gateway timeout + // if (responseMessage.StatusCode == HttpStatusCode.GatewayTimeout) { + // request.ResetSendStatus(); + // return await SendAsync(request, cancellationToken); + // } + //error handling var content = await responseMessage.Content.ReadAsStringAsync(cancellationToken); if (content.Length == 0) @@ -114,10 +167,15 @@ public class MatrixHttpClient { ErrorCode = "M_UNKNOWN", Error = "Unknown error, server returned no content" }; - if (!content.StartsWith('{')) throw new InvalidDataException("Encountered invalid data:\n" + content); + + // if (!content.StartsWith('{')) throw new InvalidDataException("Encountered invalid data:\n" + content); + if (!content.TrimStart().StartsWith('{')) { + responseMessage.EnsureSuccessStatusCode(); + throw new InvalidDataException("Encountered invalid data:\n" + content); + } //we have a matrix error - MatrixException? ex = null; + MatrixException? ex; try { ex = JsonSerializer.Deserialize<MatrixException>(content); } @@ -131,7 +189,7 @@ public class MatrixHttpClient { Debug.Assert(ex != null, nameof(ex) + " != null"); ex.RawContent = content; // Console.WriteLine($"Failed to send request: {ex}"); - if (ex?.RetryAfterMs is null) throw ex!; + if (ex.RetryAfterMs is null) throw ex!; //we have a ratelimit error await Task.Delay(ex.RetryAfterMs.Value, cancellationToken); request.ResetSendStatus(); @@ -170,7 +228,7 @@ public class MatrixHttpClient { } // GetStreamAsync - public new async Task<Stream> GetStreamAsync(string requestUri, CancellationToken cancellationToken = default) { + public async Task<Stream> GetStreamAsync(string requestUri, CancellationToken cancellationToken = default) { var request = new HttpRequestMessage(HttpMethod.Get, requestUri); request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); var response = await SendAsync(request, cancellationToken); @@ -209,7 +267,7 @@ public class MatrixHttpClient { await foreach (var resp in result) yield return resp; } - public async Task<bool> CheckSuccessStatus(string url) { + public static async Task<bool> CheckSuccessStatus(string url) { //cors causes failure, try to catch try { var resp = await Client.GetAsync(url); @@ -227,5 +285,17 @@ public class MatrixHttpClient { }; return await SendAsync(request, cancellationToken); } + + public async Task<HttpResponseMessage> DeleteAsync(string url) { + var request = new HttpRequestMessage(HttpMethod.Delete, url); + return await SendAsync(request); + } + + public async Task<HttpResponseMessage> DeleteAsJsonAsync<T>(string url, T payload) { + var request = new HttpRequestMessage(HttpMethod.Delete, url) { + Content = new StringContent(JsonSerializer.Serialize(payload), Encoding.UTF8, "application/json") + }; + return await SendAsync(request); + } } #endif \ No newline at end of file diff --git a/LibMatrix/Extensions/UnicodeJsonEncoder.cs b/LibMatrix/Extensions/UnicodeJsonEncoder.cs new file mode 100644
index 0000000..ae58263 --- /dev/null +++ b/LibMatrix/Extensions/UnicodeJsonEncoder.cs
@@ -0,0 +1,173 @@ +// LibMatrix: File sourced from https://github.com/dotnet/runtime/pull/87147/files under the MIT license. + +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text; +using System.Text.Encodings.Web; + +namespace LibMatrix.Extensions; + +internal sealed class UnicodeJsonEncoder : JavaScriptEncoder +{ + internal static readonly UnicodeJsonEncoder Singleton = new UnicodeJsonEncoder(); + + private readonly bool _preferHexEscape; + private readonly bool _preferUppercase; + + public UnicodeJsonEncoder() + : this(preferHexEscape: false, preferUppercase: false) + { + } + + public UnicodeJsonEncoder(bool preferHexEscape, bool preferUppercase) + { + _preferHexEscape = preferHexEscape; + _preferUppercase = preferUppercase; + } + + public override int MaxOutputCharactersPerInputCharacter => 6; // "\uXXXX" for a single char ("\uXXXX\uYYYY" [12 chars] for supplementary scalar value) + + public override unsafe int FindFirstCharacterToEncode(char* text, int textLength) + { + for (int index = 0; index < textLength; ++index) + { + char value = text[index]; + + if (NeedsEncoding(value)) + { + return index; + } + } + + return -1; + } + + public override unsafe bool TryEncodeUnicodeScalar(int unicodeScalar, char* buffer, int bufferLength, out int numberOfCharactersWritten) + { + bool encode = WillEncode(unicodeScalar); + + if (!encode) + { + Span<char> span = new Span<char>(buffer, bufferLength); + int spanWritten; + bool succeeded = new Rune(unicodeScalar).TryEncodeToUtf16(span, out spanWritten); + numberOfCharactersWritten = spanWritten; + return succeeded; + } + + if (!_preferHexEscape && unicodeScalar <= char.MaxValue && HasTwoCharacterEscape((char)unicodeScalar)) + { + if (bufferLength < 2) + { + numberOfCharactersWritten = 0; + return false; + } + + buffer[0] = '\\'; + buffer[1] = GetTwoCharacterEscapeSuffix((char)unicodeScalar); + numberOfCharactersWritten = 2; + return true; + } + else + { + if (bufferLength < 6) + { + numberOfCharactersWritten = 0; + return false; + } + + buffer[0] = '\\'; + buffer[1] = 'u'; + buffer[2] = '0'; + buffer[3] = '0'; + buffer[4] = ToHexDigit((unicodeScalar & 0xf0) >> 4, _preferUppercase); + buffer[5] = ToHexDigit(unicodeScalar & 0xf, _preferUppercase); + numberOfCharactersWritten = 6; + return true; + } + } + + public override bool WillEncode(int unicodeScalar) + { + if (unicodeScalar > char.MaxValue) + { + return false; + } + + return NeedsEncoding((char)unicodeScalar); + } + + // https://datatracker.ietf.org/doc/html/rfc8259#section-7 + private static bool NeedsEncoding(char value) + { + if (value == '"' || value == '\\') + { + return true; + } + + return value <= '\u001f'; + } + + private static bool HasTwoCharacterEscape(char value) + { + // RFC 8259, Section 7, "char = " BNF + switch (value) + { + case '"': + case '\\': + case '/': + case '\b': + case '\f': + case '\n': + case '\r': + case '\t': + return true; + default: + return false; + } + } + + private static char GetTwoCharacterEscapeSuffix(char value) + { + // RFC 8259, Section 7, "char = " BNF + switch (value) + { + case '"': + return '"'; + case '\\': + return '\\'; + case '/': + return '/'; + case '\b': + return 'b'; + case '\f': + return 'f'; + case '\n': + return 'n'; + case '\r': + return 'r'; + case '\t': + return 't'; + default: + throw new ArgumentOutOfRangeException(nameof(value)); + } + } + + private static char ToHexDigit(int value, bool uppercase) + { + if (value > 0xf) + { + throw new ArgumentOutOfRangeException(nameof(value)); + } + + if (value < 10) + { + return (char)(value + '0'); + } + else + { + return (char)(value - 0xa + (uppercase ? 'A' : 'a')); + } + } +} \ No newline at end of file