about summary refs log tree commit diff
path: root/LibMatrix
diff options
context:
space:
mode:
authorRory& <root@rory.gay>2024-04-19 15:54:30 +0200
committerRory& <root@rory.gay>2024-04-19 15:54:30 +0200
commit440807e02393410327cd86d5ffa007dee98f8954 (patch)
treee750b0bab55a9ee7b507cd48eaa4ccb2ddd25fc0 /LibMatrix
parentFix homeserver resolution, rewrite homeserver initialisation, HSE work (diff)
downloadLibMatrix-440807e02393410327cd86d5ffa007dee98f8954.tar.xz
Partial User-Interactive Authentication, allow skipping homeserver typing
Diffstat (limited to 'LibMatrix')
-rw-r--r--LibMatrix/Extensions/HttpClientExtensions.cs21
-rw-r--r--LibMatrix/Homeservers/AuthenticatedHomeserverGeneric.cs18
-rw-r--r--LibMatrix/Homeservers/RemoteHomeServer.cs4
-rw-r--r--LibMatrix/Homeservers/UserInteractiveAuthClient.cs85
-rw-r--r--LibMatrix/RoomTypes/GenericRoom.cs28
-rw-r--r--LibMatrix/Services/HomeserverProviderService.cs69
-rw-r--r--LibMatrix/StateEvent.cs13
7 files changed, 185 insertions, 53 deletions
diff --git a/LibMatrix/Extensions/HttpClientExtensions.cs b/LibMatrix/Extensions/HttpClientExtensions.cs
index 598f8e5..01ce6ea 100644
--- a/LibMatrix/Extensions/HttpClientExtensions.cs
+++ b/LibMatrix/Extensions/HttpClientExtensions.cs
@@ -38,7 +38,7 @@ public class MatrixHttpClient : HttpClient {
         return options;
     }
 
-    public async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) {
+    public async Task<HttpResponseMessage> SendUnhandledAsync(HttpRequestMessage request, CancellationToken cancellationToken) {
         Console.WriteLine($"Sending {request.Method} {BaseAddress}{request.RequestUri} ({Util.BytesToString(request.Content?.Headers.ContentLength ?? 0)})");
         if (request.RequestUri is null) throw new NullReferenceException("RequestUri is null");
         if (!request.RequestUri.IsAbsoluteUri) request.RequestUri = new Uri(BaseAddress, request.RequestUri);
@@ -57,20 +57,13 @@ public class MatrixHttpClient : HttpClient {
             Console.WriteLine(e);
         }
 
-        HttpResponseMessage responseMessage;
-        // try {
-        responseMessage = await base.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
-        // }
-        // catch (Exception e) {
-        // if (requestSettings is { Retries: 0 }) throw;
-        // typeof(HttpRequestMessage).GetField("_sendStatus", BindingFlags.NonPublic | BindingFlags.Instance)
-        // ?.SetValue(request, 0);
-        // await Task.Delay(requestSettings?.RetryDelay ?? 2500, cancellationToken);
-        // if(requestSettings is not null) requestSettings.Retries--;
-        // return await SendAsync(request, cancellationToken);
-        // throw;
-        // }
+        var responseMessage = await base.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
+
+        return responseMessage;
+    }
 
+    public async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) {
+        var responseMessage = await SendUnhandledAsync(request, cancellationToken);
         if (responseMessage.IsSuccessStatusCode) return responseMessage;
 
         //error handling
diff --git a/LibMatrix/Homeservers/AuthenticatedHomeserverGeneric.cs b/LibMatrix/Homeservers/AuthenticatedHomeserverGeneric.cs
index afa6a6c..267b54d 100644
--- a/LibMatrix/Homeservers/AuthenticatedHomeserverGeneric.cs
+++ b/LibMatrix/Homeservers/AuthenticatedHomeserverGeneric.cs
@@ -128,10 +128,20 @@ public class AuthenticatedHomeserverGeneric : RemoteHomeserver {
     public virtual async IAsyncEnumerable<GenericRoom> GetJoinedRoomsByType(string type) {
         var rooms = await GetJoinedRooms();
         var tasks = rooms.Select(async room => {
-            var roomType = await room.GetRoomType();
-            if (roomType == type) return room;
-
-            return null;
+            while (true) {
+                try {
+                    var roomType = await room.GetRoomType();
+                    if (roomType == type) return room;
+                    return null;
+                }
+                catch (MatrixException e) {
+                    throw;
+                }
+                catch (Exception e) {
+                    Console.WriteLine($"Failed to get room type for {room.RoomId}: {e.Message}");
+                    await Task.Delay(1000);
+                }
+            }
         }).ToAsyncEnumerable();
 
         await foreach (var result in tasks)
diff --git a/LibMatrix/Homeservers/RemoteHomeServer.cs b/LibMatrix/Homeservers/RemoteHomeServer.cs
index e6d58b1..c29137c 100644
--- a/LibMatrix/Homeservers/RemoteHomeServer.cs
+++ b/LibMatrix/Homeservers/RemoteHomeServer.cs
@@ -1,5 +1,6 @@
 using System.Net.Http.Json;
 using System.Text.Json;
+using System.Text.Json.Nodes;
 using System.Text.Json.Serialization;
 using System.Web;
 using ArcaneLibs.Extensions;
@@ -24,6 +25,7 @@ public class RemoteHomeserver {
         if (proxy is not null) ClientHttpClient.DefaultRequestHeaders.Add("MXAE_UPSTREAM", baseUrl);
         if (!string.IsNullOrWhiteSpace(wellKnownUris.Server))
             FederationClient = new FederationClient(WellKnownUris.Server!, proxy);
+        Auth = new(this);
     }
 
     private Dictionary<string, object> _profileCache { get; set; } = new();
@@ -106,6 +108,8 @@ public class RemoteHomeserver {
         if (mxcUri.StartsWith("https://")) return mxcUri;
         return $"{ClientHttpClient.BaseAddress}/_matrix/media/v3/download/{mxcUri.Replace("mxc://", "")}".Replace("//_matrix", "/_matrix");
     }
+
+    public UserInteractiveAuthClient Auth;
 }
 
 public class AliasResult {
diff --git a/LibMatrix/Homeservers/UserInteractiveAuthClient.cs b/LibMatrix/Homeservers/UserInteractiveAuthClient.cs
new file mode 100644
index 0000000..8be2cb9
--- /dev/null
+++ b/LibMatrix/Homeservers/UserInteractiveAuthClient.cs
@@ -0,0 +1,85 @@
+using System.Net.Http.Json;
+using System.Text.Json.Nodes;
+using System.Text.Json.Serialization;
+using ArcaneLibs.Extensions;
+using LibMatrix.Responses;
+
+namespace LibMatrix.Homeservers;
+
+public class UserInteractiveAuthClient {
+    public UserInteractiveAuthClient(RemoteHomeserver hs) {
+        Homeserver = hs;
+    }
+
+    [JsonIgnore]
+    public RemoteHomeserver Homeserver { get; }
+    private LoginResponse? _guestLogin;
+
+    public async Task<UIAStage1Client> GetAvailableFlowsAsync(bool enableRegister = false, bool enableGuest = false) {
+        // var resp = await Homeserver.ClientHttpClient.GetAsync("/_matrix/client/v3/login");
+        // var data = await resp.Content.ReadFromJsonAsync<LoginFlowsResponse>();
+        // if (!resp.IsSuccessStatusCode) Console.WriteLine("LoginFlows: " + await resp.Content.ReadAsStringAsync());
+        // var loginFlows = data;
+        //
+        // try {
+        //     var req = new HttpRequestMessage(HttpMethod.Post, "/_matrix/client/v3/register") {
+        //         Content = new StringContent("{}")
+        //     };
+        //     var resp2 = await Homeserver.ClientHttpClient.SendUnhandledAsync(req, CancellationToken.None);
+        //     var data2 = await resp2.Content.ReadFromJsonAsync<RegisterFlowsResponse>();
+        //     if (!resp.IsSuccessStatusCode) Console.WriteLine("RegisterFlows: " + data2.ToJson());
+        //     // return data;
+        // }
+        // catch (MatrixException e) {
+        //     if (e is { ErrorCode: "M_FORBIDDEN" }) return null;
+        //     throw;
+        // }
+        // catch (Exception e) {
+        //     Console.WriteLine(e);
+        //     throw;
+        // }
+        //
+        //
+        return new UIAStage1Client() {
+            
+        };
+    }
+
+    private async Task<RegisterFlowsResponse?> GetRegisterFlowsAsync() {
+        return null;
+    }
+
+    internal class RegisterFlowsResponse {
+        [JsonPropertyName("session")]
+        public string Session { get; set; } = null!;
+
+        [JsonPropertyName("flows")]
+        public List<RegisterFlow> Flows { get; set; } = null!;
+
+        [JsonPropertyName("params")]
+        public JsonObject Params { get; set; } = null!;
+
+        public class RegisterFlow {
+            [JsonPropertyName("stages")]
+            public List<string> Stages { get; set; } = null!;
+        }
+    }
+
+    internal class LoginFlowsResponse {
+        [JsonPropertyName("flows")]
+        public List<LoginFlow> Flows { get; set; } = null!;
+
+        public class LoginFlow {
+            [JsonPropertyName("type")]
+            public string Type { get; set; } = null!;
+        }
+    }
+
+    public interface IUIAStage {
+        public IUIAStage? PreviousStage { get; }
+    }
+    public class UIAStage1Client : IUIAStage {
+        public IUIAStage? PreviousStage { get; }
+        // public LoginFlowsResponse LoginFlows { get; set; }
+    }
+}
\ No newline at end of file
diff --git a/LibMatrix/RoomTypes/GenericRoom.cs b/LibMatrix/RoomTypes/GenericRoom.cs
index 36abadc..e4d2b9c 100644
--- a/LibMatrix/RoomTypes/GenericRoom.cs
+++ b/LibMatrix/RoomTypes/GenericRoom.cs
@@ -160,7 +160,7 @@ public class GenericRoom {
         Console.WriteLine("End of GetManyAsync");
     }
 
-    public async Task<string?> GetNameAsync() => (await GetStateAsync<RoomNameEventContent>("m.room.name"))?.Name;
+    public async Task<string?> GetNameAsync() => (await GetStateOrNullAsync<RoomNameEventContent>("m.room.name"))?.Name;
 
     public async Task<RoomIdResponse> JoinAsync(string[]? homeservers = null, string? reason = null, bool checkIfAlreadyMember = true) {
         if (checkIfAlreadyMember)
@@ -406,7 +406,7 @@ public class GenericRoom {
         }
     }
 
-    public Task<T> GetEventAsync<T>(string eventId) => Homeserver.ClientHttpClient.GetFromJsonAsync<T>($"/_matrix/client/v3/rooms/{RoomId}/event/{eventId}");
+    public Task<StateEventResponse> GetEventAsync(string eventId) => Homeserver.ClientHttpClient.GetFromJsonAsync<StateEventResponse>($"/_matrix/client/v3/rooms/{RoomId}/event/{eventId}");
 
     public async Task<EventIdResponse> RedactEventAsync(string eventToRedact, string reason) {
         var data = new { reason };
@@ -465,6 +465,30 @@ public class GenericRoom {
 
 #endregion
 
+    public async IAsyncEnumerable<StateEventResponse> GetRelatedEventsAsync(string eventId, string? relationType = null, string? eventType = null, string? dir = "f",
+        string? from = null, int? chunkLimit = 100, bool? recurse = false, string? to = null) {
+        var path = $"/_matrix/client/v3/rooms/{RoomId}/relations/{eventId}";
+        if (!string.IsNullOrEmpty(relationType)) path += $"/{relationType}";
+        if (!string.IsNullOrEmpty(eventType)) path += $"/{eventType}";
+
+        var uri = new Uri(path, UriKind.Relative);
+        if (dir == "b" || dir == "f") uri = uri.AddQuery("dir", dir);
+        if (!string.IsNullOrEmpty(from)) uri = uri.AddQuery("from", from);
+        if (chunkLimit is not null) uri = uri.AddQuery("limit", chunkLimit.Value.ToString());
+        if (recurse is not null) uri = uri.AddQuery("recurse", recurse.Value.ToString());
+        if (!string.IsNullOrEmpty(to)) uri = uri.AddQuery("to", to);
+
+        var result = await Homeserver.ClientHttpClient.GetFromJsonAsync<RecursedBatchedChunkedStateEventResponse>(uri);
+        while (result!.Chunk.Count > 0) {
+            foreach (var resp in result.Chunk) {
+                yield return resp;
+            }
+
+            if (result.NextBatch is null) break;
+            result = await Homeserver.ClientHttpClient.GetFromJsonAsync<RecursedBatchedChunkedStateEventResponse>(uri.AddQuery("from", result.NextBatch));
+        }
+    }
+
     public readonly SpaceRoom AsSpace;
 }
 
diff --git a/LibMatrix/Services/HomeserverProviderService.cs b/LibMatrix/Services/HomeserverProviderService.cs
index 3995a26..c61ef73 100644
--- a/LibMatrix/Services/HomeserverProviderService.cs
+++ b/LibMatrix/Services/HomeserverProviderService.cs
@@ -11,43 +11,47 @@ public class HomeserverProviderService(ILogger<HomeserverProviderService> logger
     private static SemaphoreCache<AuthenticatedHomeserverGeneric> AuthenticatedHomeserverCache = new();
     private static SemaphoreCache<RemoteHomeserver> RemoteHomeserverCache = new();
 
-    public async Task<AuthenticatedHomeserverGeneric> GetAuthenticatedWithToken(string homeserver, string accessToken, string? proxy = null, string? impersonatedMxid = null) {
+    public async Task<AuthenticatedHomeserverGeneric> GetAuthenticatedWithToken(string homeserver, string accessToken, string? proxy = null, string? impersonatedMxid = null,
+        bool useGeneric = false) {
         return await AuthenticatedHomeserverCache.GetOrAdd($"{homeserver}{accessToken}{proxy}{impersonatedMxid}", async () => {
             var wellKnownUris = await hsResolver.ResolveHomeserverFromWellKnown(homeserver);
             var rhs = new RemoteHomeserver(homeserver, wellKnownUris, ref proxy);
+            
+            AuthenticatedHomeserverGeneric? hs = null;
+            if (!useGeneric)
+            {
+                ClientVersionsResponse? clientVersions = new();
+                try {
+                    clientVersions = await rhs.GetClientVersionsAsync();
+                }
+                catch (Exception e) {
+                    logger.LogError(e, "Failed to get client versions for {homeserver}", homeserver);
+                }
 
-            ClientVersionsResponse? clientVersions = new();
-            try {
-                clientVersions = await rhs.GetClientVersionsAsync();
-            }
-            catch (Exception e) {
-                logger.LogError(e, "Failed to get client versions for {homeserver}", homeserver);
-            }
-
-            ServerVersionResponse? serverVersion;
-            try {
-                serverVersion = await (rhs.FederationClient?.GetServerVersionAsync() ?? Task.FromResult<ServerVersionResponse?>(null)!);
-            }
-            catch (Exception e) {
-                logger.LogWarning(e, "Failed to get server version for {homeserver}", homeserver);
-                throw;
-            }
+                ServerVersionResponse? serverVersion;
+                try {
+                    serverVersion = await (rhs.FederationClient?.GetServerVersionAsync() ?? Task.FromResult<ServerVersionResponse?>(null)!);
+                }
+                catch (Exception e) {
+                    logger.LogWarning(e, "Failed to get server version for {homeserver}", homeserver);
+                    throw;
+                }
 
-            AuthenticatedHomeserverGeneric hs;
-            try {
-                if (clientVersions.UnstableFeatures.TryGetValue("gay.rory.mxapiextensions.v0", out var a) && a)
-                    hs = new AuthenticatedHomeserverMxApiExtended(homeserver, wellKnownUris, ref proxy, accessToken);
-                else {
-                    if (serverVersion is { Server.Name: "Synapse" })
-                        hs = new AuthenticatedHomeserverSynapse(homeserver, wellKnownUris, ref proxy, accessToken);
-                    else
-                        hs = new AuthenticatedHomeserverGeneric(homeserver, wellKnownUris, ref proxy, accessToken);
+                try {
+                    if (clientVersions.UnstableFeatures.TryGetValue("gay.rory.mxapiextensions.v0", out var a) && a)
+                        hs = new AuthenticatedHomeserverMxApiExtended(homeserver, wellKnownUris, ref proxy, accessToken);
+                    else {
+                        if (serverVersion is { Server.Name: "Synapse" })
+                            hs = new AuthenticatedHomeserverSynapse(homeserver, wellKnownUris, ref proxy, accessToken);
+                    }
+                }
+                catch (Exception e) {
+                    logger.LogError(e, "Failed to create authenticated homeserver for {homeserver}", homeserver);
+                    throw;
                 }
             }
-            catch (Exception e) {
-                logger.LogError(e, "Failed to create authenticated homeserver for {homeserver}", homeserver);
-                throw;
-            }
+            
+            hs ??= new AuthenticatedHomeserverGeneric(homeserver, wellKnownUris, ref proxy, accessToken);
 
             await hs.Initialise();
 
@@ -59,9 +63,8 @@ public class HomeserverProviderService(ILogger<HomeserverProviderService> logger
     }
 
     public async Task<RemoteHomeserver> GetRemoteHomeserver(string homeserver, string? proxy = null) =>
-        await RemoteHomeserverCache.GetOrAdd($"{homeserver}{proxy}", async () => {
-            return new RemoteHomeserver(homeserver, await hsResolver.ResolveHomeserverFromWellKnown(homeserver), ref proxy);
-        });
+        await RemoteHomeserverCache.GetOrAdd($"{homeserver}{proxy}",
+            async () => { return new RemoteHomeserver(homeserver, await hsResolver.ResolveHomeserverFromWellKnown(homeserver), ref proxy); });
 
     public async Task<LoginResponse> Login(string homeserver, string user, string password, string? proxy = null) {
         var hs = await GetRemoteHomeserver(homeserver, proxy);
diff --git a/LibMatrix/StateEvent.cs b/LibMatrix/StateEvent.cs
index 26c6a5f..f504c99 100644
--- a/LibMatrix/StateEvent.cs
+++ b/LibMatrix/StateEvent.cs
@@ -203,6 +203,19 @@ public class PaginatedChunkedStateEventResponse : ChunkedStateEventResponse {
     public string? End { get; set; }
 }
 
+public class BatchedChunkedStateEventResponse : ChunkedStateEventResponse {
+    [JsonPropertyName("next_batch")]
+    public string? NextBatch { get; set; }
+    
+    [JsonPropertyName("prev_batch")]
+    public string? PrevBatch { get; set; }
+}
+
+public class RecursedBatchedChunkedStateEventResponse : BatchedChunkedStateEventResponse {
+    [JsonPropertyName("recursion_depth")]
+    public int? RecursionDepth { get; set; }
+}
+
 #region Unused code
 
 /*