diff options
Diffstat (limited to 'LibMatrix/Extensions')
-rw-r--r-- | LibMatrix/Extensions/ClassCollector.cs | 22 | ||||
-rw-r--r-- | LibMatrix/Extensions/DictionaryExtensions.cs | 33 | ||||
-rw-r--r-- | LibMatrix/Extensions/HttpClientExtensions.cs | 76 | ||||
-rw-r--r-- | LibMatrix/Extensions/IEnumerableExtensions.cs | 7 | ||||
-rw-r--r-- | LibMatrix/Extensions/JsonElementExtensions.cs | 150 | ||||
-rw-r--r-- | LibMatrix/Extensions/ObjectExtensions.cs | 14 | ||||
-rw-r--r-- | LibMatrix/Extensions/StringExtensions.cs | 13 |
7 files changed, 315 insertions, 0 deletions
diff --git a/LibMatrix/Extensions/ClassCollector.cs b/LibMatrix/Extensions/ClassCollector.cs new file mode 100644 index 0000000..f53850a --- /dev/null +++ b/LibMatrix/Extensions/ClassCollector.cs @@ -0,0 +1,22 @@ +using System.Reflection; + +namespace LibMatrix.Extensions; + +public class ClassCollector<T> where T : class { + static ClassCollector() { + if (!typeof(T).IsInterface) + throw new ArgumentException( + $"ClassCollector<T> must be used with an interface type. Passed type: {typeof(T).Name}"); + } + + public List<Type> ResolveFromAllAccessibleAssemblies() => AppDomain.CurrentDomain.GetAssemblies().SelectMany(ResolveFromAssembly).ToList(); + + public List<Type> ResolveFromObjectReference(object obj) => ResolveFromTypeReference(obj.GetType()); + + public List<Type> ResolveFromTypeReference(Type t) => Assembly.GetAssembly(t)?.GetReferencedAssemblies().SelectMany(ResolveFromAssemblyName).ToList() ?? new List<Type>(); + + public List<Type> ResolveFromAssemblyName(AssemblyName assemblyName) => ResolveFromAssembly(Assembly.Load(assemblyName)); + + public List<Type> ResolveFromAssembly(Assembly assembly) => assembly.GetTypes() + .Where(x => x is { IsClass: true, IsAbstract: false } && x.GetInterfaces().Contains(typeof(T))).ToList(); +} diff --git a/LibMatrix/Extensions/DictionaryExtensions.cs b/LibMatrix/Extensions/DictionaryExtensions.cs new file mode 100644 index 0000000..fbc5cf5 --- /dev/null +++ b/LibMatrix/Extensions/DictionaryExtensions.cs @@ -0,0 +1,33 @@ +namespace LibMatrix.Extensions; + +public static class DictionaryExtensions { + public static bool ChangeKey<TKey, TValue>(this IDictionary<TKey, TValue> dict, + TKey oldKey, TKey newKey) { + TValue value; + if (!dict.Remove(oldKey, out value)) + return false; + + dict[newKey] = value; // or dict.Add(newKey, value) depending on ur comfort + return true; + } + + public static Y GetOrCreate<X, Y>(this IDictionary<X, Y> dict, X key) where Y : new() { + if (dict.TryGetValue(key, out var value)) { + return value; + } + + value = new Y(); + dict.Add(key, value); + return value; + } + + public static Y GetOrCreate<X, Y>(this IDictionary<X, Y> dict, X key, Func<X, Y> valueFactory) { + if (dict.TryGetValue(key, out var value)) { + return value; + } + + value = valueFactory(key); + dict.Add(key, value); + return value; + } +} diff --git a/LibMatrix/Extensions/HttpClientExtensions.cs b/LibMatrix/Extensions/HttpClientExtensions.cs new file mode 100644 index 0000000..797a077 --- /dev/null +++ b/LibMatrix/Extensions/HttpClientExtensions.cs @@ -0,0 +1,76 @@ +using System.Net.Http.Headers; +using System.Reflection; +using System.Text.Json; + +namespace LibMatrix.Extensions; + +public static class HttpClientExtensions { + public static async Task<bool> CheckSuccessStatus(this HttpClient hc, string url) { + //cors causes failure, try to catch + try { + var resp = await hc.GetAsync(url); + return resp.IsSuccessStatusCode; + } + catch (Exception e) { + Console.WriteLine($"Failed to check success status: {e.Message}"); + return false; + } + } +} + +public class MatrixHttpClient : HttpClient { + public override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, + CancellationToken cancellationToken) { + Console.WriteLine($"Sending request to {request.RequestUri}"); + try { + HttpRequestOptionsKey<bool> WebAssemblyEnableStreamingResponseKey = + new HttpRequestOptionsKey<bool>("WebAssemblyEnableStreamingResponse"); + request.Options.Set(WebAssemblyEnableStreamingResponseKey, true); + } + catch (Exception e) { + Console.WriteLine("Failed to set browser response streaming:"); + Console.WriteLine(e); + } + + var a = await base.SendAsync(request, cancellationToken); + if (!a.IsSuccessStatusCode) { + var content = await a.Content.ReadAsStringAsync(cancellationToken); + if (content.StartsWith('{')) { + var ex = JsonSerializer.Deserialize<MatrixException>(content); + ex.RawContent = content; + // Console.WriteLine($"Failed to send request: {ex}"); + if (ex?.RetryAfterMs is not null) { + await Task.Delay(ex.RetryAfterMs.Value, cancellationToken); + typeof(HttpRequestMessage).GetField("_sendStatus", BindingFlags.NonPublic | BindingFlags.Instance) + ?.SetValue(request, 0); + return await SendAsync(request, cancellationToken); + } + + throw ex!; + } + + throw new InvalidDataException("Encountered invalid data:\n" + content); + } + + return a; + } + + // GetFromJsonAsync + public async Task<T> GetFromJsonAsync<T>(string requestUri, CancellationToken cancellationToken = default) { + var request = new HttpRequestMessage(HttpMethod.Get, requestUri); + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + var response = await SendAsync(request, cancellationToken); + response.EnsureSuccessStatusCode(); + await using var responseStream = await response.Content.ReadAsStreamAsync(cancellationToken); + return await JsonSerializer.DeserializeAsync<T>(responseStream, cancellationToken: cancellationToken); + } + + // GetStreamAsync + public async Task<Stream> GetStreamAsync(string requestUri, CancellationToken cancellationToken = default) { + var request = new HttpRequestMessage(HttpMethod.Get, requestUri); + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + var response = await SendAsync(request, cancellationToken); + response.EnsureSuccessStatusCode(); + return await response.Content.ReadAsStreamAsync(cancellationToken); + } +} diff --git a/LibMatrix/Extensions/IEnumerableExtensions.cs b/LibMatrix/Extensions/IEnumerableExtensions.cs new file mode 100644 index 0000000..8124947 --- /dev/null +++ b/LibMatrix/Extensions/IEnumerableExtensions.cs @@ -0,0 +1,7 @@ +namespace LibMatrix.Extensions; + +[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)] +public class MatrixEventAttribute : Attribute { + public string EventName { get; set; } + public bool Legacy { get; set; } +} diff --git a/LibMatrix/Extensions/JsonElementExtensions.cs b/LibMatrix/Extensions/JsonElementExtensions.cs new file mode 100644 index 0000000..caf96e1 --- /dev/null +++ b/LibMatrix/Extensions/JsonElementExtensions.cs @@ -0,0 +1,150 @@ +using System.Reflection; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using LibMatrix.Responses; + +namespace LibMatrix.Extensions; + +public static class JsonElementExtensions { + public static bool FindExtraJsonElementFields(this JsonElement obj, Type objectType, string objectPropertyName) { + if (objectPropertyName == "content" && objectType == typeof(JsonObject)) + objectType = typeof(StateEventResponse); + // if (t == typeof(JsonNode)) + // return false; + + Console.WriteLine($"{objectType.Name} {objectPropertyName}"); + bool unknownPropertyFound = false; + var mappedPropsDict = objectType.GetProperties() + .Where(x => x.GetCustomAttribute<JsonPropertyNameAttribute>() is not null) + .ToDictionary(x => x.GetCustomAttribute<JsonPropertyNameAttribute>()!.Name, x => x); + objectType.GetProperties().Where(x => !mappedPropsDict.ContainsKey(x.Name)) + .ToList().ForEach(x => mappedPropsDict.TryAdd(x.Name, x)); + + foreach (var field in obj.EnumerateObject()) { + if (mappedPropsDict.TryGetValue(field.Name, out var mappedProperty)) { + //dictionary + if (mappedProperty.PropertyType.IsGenericType && + mappedProperty.PropertyType.GetGenericTypeDefinition() == typeof(Dictionary<,>)) { + unknownPropertyFound |= _checkDictionary(field, objectType, mappedProperty.PropertyType); + continue; + } + + if (mappedProperty.PropertyType.IsGenericType && + mappedProperty.PropertyType.GetGenericTypeDefinition() == typeof(List<>)) { + unknownPropertyFound |= _checkList(field, objectType, mappedProperty.PropertyType); + continue; + } + + if (field.Name == "content" && (objectType == typeof(StateEventResponse) || objectType == typeof(StateEvent))) { + unknownPropertyFound |= field.FindExtraJsonPropertyFieldsByValueKind( + StateEvent.GetStateEventType(obj.GetProperty("type").GetString()), + mappedProperty.PropertyType); + continue; + } + + unknownPropertyFound |= + field.FindExtraJsonPropertyFieldsByValueKind(objectType, mappedProperty.PropertyType); + continue; + } + + Console.WriteLine($"[!!] Unknown property {field.Name} in {objectType.Name}!"); + unknownPropertyFound = true; + } + + return unknownPropertyFound; + } + + private static bool FindExtraJsonPropertyFieldsByValueKind(this JsonProperty field, Type containerType, + Type propertyType) { + if (propertyType.IsGenericType && propertyType.GetGenericTypeDefinition() == typeof(Nullable<>)) { + propertyType = propertyType.GetGenericArguments()[0]; + } + + bool switchResult = false; + switch (field.Value.ValueKind) { + case JsonValueKind.Array: + switchResult = field.Value.EnumerateArray().Aggregate(switchResult, + (current, element) => current | element.FindExtraJsonElementFields(propertyType, field.Name)); + break; + case JsonValueKind.Object: + switchResult |= field.Value.FindExtraJsonElementFields(propertyType, field.Name); + break; + case JsonValueKind.True: + case JsonValueKind.False: + return _checkBool(field, containerType, propertyType); + case JsonValueKind.String: + return _checkString(field, containerType, propertyType); + case JsonValueKind.Number: + return _checkNumber(field, containerType, propertyType); + case JsonValueKind.Undefined: + case JsonValueKind.Null: + break; + default: + throw new ArgumentOutOfRangeException(); + } + + return switchResult; + } + + private static bool _checkBool(this JsonProperty field, Type containerType, Type propertyType) { + if (propertyType == typeof(bool)) return true; + Console.WriteLine( + $"[!!] Encountered bool for {field.Name} in {containerType.Name}, the class defines {propertyType.Name}!"); + return false; + } + + private static bool _checkString(this JsonProperty field, Type containerType, Type propertyType) { + if (propertyType == typeof(string)) return true; + // ReSharper disable once BuiltInTypeReferenceStyle + if (propertyType == typeof(String)) return true; + Console.WriteLine( + $"[!!] Encountered string for {field.Name} in {containerType.Name}, the class defines {propertyType.Name}!"); + return false; + } + + private static bool _checkNumber(this JsonProperty field, Type containerType, Type propertyType) { + if (propertyType == typeof(int) || + propertyType == typeof(double) || + propertyType == typeof(float) || + propertyType == typeof(decimal) || + propertyType == typeof(long) || + propertyType == typeof(short) || + propertyType == typeof(uint) || + propertyType == typeof(ulong) || + propertyType == typeof(ushort) || + propertyType == typeof(byte) || + propertyType == typeof(sbyte)) + return true; + Console.WriteLine( + $"[!!] Encountered number for {field.Name} in {containerType.Name}, the class defines {propertyType.Name}!"); + return false; + } + + private static bool _checkDictionary(this JsonProperty field, Type containerType, Type propertyType) { + var keyType = propertyType.GetGenericArguments()[0]; + var valueType = propertyType.GetGenericArguments()[1]; + valueType = Nullable.GetUnderlyingType(valueType) ?? valueType; + Console.WriteLine( + $"Encountered dictionary {field.Name} with key type {keyType.Name} and value type {valueType.Name}!"); + + return field.Value.EnumerateObject() + .Where(key => !valueType.IsPrimitive && valueType != typeof(string)) + .Aggregate(false, (current, key) => + current | key.FindExtraJsonPropertyFieldsByValueKind(containerType, valueType) + ); + } + + private static bool _checkList(this JsonProperty field, Type containerType, Type propertyType) { + var valueType = propertyType.GetGenericArguments()[0]; + valueType = Nullable.GetUnderlyingType(valueType) ?? valueType; + Console.WriteLine( + $"Encountered list {field.Name} with value type {valueType.Name}!"); + + return field.Value.EnumerateArray() + .Where(key => !valueType.IsPrimitive && valueType != typeof(string)) + .Aggregate(false, (current, key) => + current | key.FindExtraJsonElementFields(valueType, field.Name) + ); + } +} diff --git a/LibMatrix/Extensions/ObjectExtensions.cs b/LibMatrix/Extensions/ObjectExtensions.cs new file mode 100644 index 0000000..085de7d --- /dev/null +++ b/LibMatrix/Extensions/ObjectExtensions.cs @@ -0,0 +1,14 @@ +using System.Text.Encodings.Web; +using System.Text.Json; + +namespace LibMatrix.Extensions; + +public static class ObjectExtensions { + public static string ToJson(this object obj, bool indent = true, bool ignoreNull = false, bool unsafeContent = false) { + var jso = new JsonSerializerOptions(); + if (indent) jso.WriteIndented = true; + if (ignoreNull) jso.IgnoreNullValues = true; + if (unsafeContent) jso.Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping; + return JsonSerializer.Serialize(obj, jso); + } +} diff --git a/LibMatrix/Extensions/StringExtensions.cs b/LibMatrix/Extensions/StringExtensions.cs new file mode 100644 index 0000000..491fa77 --- /dev/null +++ b/LibMatrix/Extensions/StringExtensions.cs @@ -0,0 +1,13 @@ +namespace LibMatrix.Extensions; + +public static class StringExtensions { + // public static async Task<string> GetMediaUrl(this string MxcUrl) + // { + // //MxcUrl: mxc://rory.gay/ocRVanZoUTCcifcVNwXgbtTg + // //target: https://matrix.rory.gay/_matrix/media/v3/download/rory.gay/ocRVanZoUTCcifcVNwXgbtTg + // + // var server = MxcUrl.Split('/')[2]; + // var mediaId = MxcUrl.Split('/')[3]; + // return $"{(await new RemoteHomeServer(server).Configure()).FullHomeServerDomain}/_matrix/media/v3/download/{server}/{mediaId}"; + // } +} |