about summary refs log tree commit diff
path: root/LibMatrix/Services/HomeserverProviderService.cs
diff options
context:
space:
mode:
Diffstat (limited to 'LibMatrix/Services/HomeserverProviderService.cs')
-rw-r--r--LibMatrix/Services/HomeserverProviderService.cs126
1 files changed, 45 insertions, 81 deletions
diff --git a/LibMatrix/Services/HomeserverProviderService.cs b/LibMatrix/Services/HomeserverProviderService.cs
index 8e2e15b..3995a26 100644
--- a/LibMatrix/Services/HomeserverProviderService.cs
+++ b/LibMatrix/Services/HomeserverProviderService.cs
@@ -1,4 +1,5 @@
 using System.Net.Http.Json;
+using ArcaneLibs.Collections;
 using ArcaneLibs.Extensions;
 using LibMatrix.Homeservers;
 using LibMatrix.Responses;
@@ -6,99 +7,62 @@ using Microsoft.Extensions.Logging;
 
 namespace LibMatrix.Services;
 
-public class HomeserverProviderService(ILogger<HomeserverProviderService> logger) {
-    private static readonly Dictionary<string, SemaphoreSlim> AuthenticatedHomeserverSemaphore = new();
-    private static readonly Dictionary<string, AuthenticatedHomeserverGeneric> AuthenticatedHomeserverCache = new();
-
-    private static readonly Dictionary<string, SemaphoreSlim> RemoteHomeserverSemaphore = new();
-    private static readonly Dictionary<string, RemoteHomeserver> RemoteHomeserverCache = new();
+public class HomeserverProviderService(ILogger<HomeserverProviderService> logger, HomeserverResolverService hsResolver) {
+    private static SemaphoreCache<AuthenticatedHomeserverGeneric> AuthenticatedHomeserverCache = new();
+    private static SemaphoreCache<RemoteHomeserver> RemoteHomeserverCache = new();
 
     public async Task<AuthenticatedHomeserverGeneric> GetAuthenticatedWithToken(string homeserver, string accessToken, string? proxy = null, string? impersonatedMxid = null) {
-        var cacheKey = homeserver + accessToken + proxy + impersonatedMxid;
-        var sem = AuthenticatedHomeserverSemaphore.GetOrCreate(cacheKey, _ => new SemaphoreSlim(1, 1));
-        await sem.WaitAsync();
-        AuthenticatedHomeserverGeneric? hs;
-        lock (AuthenticatedHomeserverCache) {
-            if (AuthenticatedHomeserverCache.TryGetValue(cacheKey, out hs)) {
-                sem.Release();
-                return hs;
-            }
-        }
-
-        var rhs = await RemoteHomeserver.Create(homeserver, proxy);
-        ClientVersionsResponse clientVersions = new();
-        try {
-            clientVersions = await rhs.GetClientVersionsAsync();
-        }
-        catch (Exception e) {
-            logger.LogError(e, "Failed to get client versions for {homeserver}", homeserver);
-        }
-
-        if (proxy is not null)
-            logger.LogInformation("Homeserver {homeserver} proxied via {proxy}...", homeserver, proxy);
-        logger.LogInformation("{homeserver}: {clientVersions}", homeserver, clientVersions.ToJson());
+        return await AuthenticatedHomeserverCache.GetOrAdd($"{homeserver}{accessToken}{proxy}{impersonatedMxid}", async () => {
+            var wellKnownUris = await hsResolver.ResolveHomeserverFromWellKnown(homeserver);
+            var rhs = new RemoteHomeserver(homeserver, wellKnownUris, ref proxy);
 
-        ServerVersionResponse serverVersion;
-        try {
-            serverVersion = serverVersion = await (rhs.FederationClient?.GetServerVersionAsync() ?? Task.FromResult<ServerVersionResponse?>(null)!);
-        }
-        catch (Exception e) {
-            logger.LogWarning(e, "Failed to get server version for {homeserver}", homeserver);
-            sem.Release();
-            throw;
-        }
-
-        try {
-            if (clientVersions.UnstableFeatures.TryGetValue("gay.rory.mxapiextensions.v0", out var a) && a)
-                hs = await AuthenticatedHomeserverGeneric.Create<AuthenticatedHomeserverMxApiExtended>(homeserver, accessToken, proxy);
-            else {
-                if (serverVersion is { Server.Name: "Synapse" })
-                    hs = await AuthenticatedHomeserverGeneric.Create<AuthenticatedHomeserverSynapse>(homeserver, accessToken, proxy);
-                else
-                    hs = await AuthenticatedHomeserverGeneric.Create<AuthenticatedHomeserverGeneric>(homeserver, accessToken, proxy);
+            ClientVersionsResponse? clientVersions = new();
+            try {
+                clientVersions = await rhs.GetClientVersionsAsync();
+            }
+            catch (Exception e) {
+                logger.LogError(e, "Failed to get client versions for {homeserver}", homeserver);
             }
-        }
-        catch (Exception e) {
-            logger.LogError(e, "Failed to create authenticated homeserver for {homeserver}", homeserver);
-            sem.Release();
-            throw;
-        }
-
-        if (impersonatedMxid is not null)
-            await hs.SetImpersonate(impersonatedMxid);
-
-        lock (AuthenticatedHomeserverCache) {
-            AuthenticatedHomeserverCache[cacheKey] = hs;
-        }
-
-        sem.Release();
-
-        return hs;
-    }
 
-    public async Task<RemoteHomeserver> GetRemoteHomeserver(string homeserver, string? proxy = null) {
-        var cacheKey = homeserver + proxy;
-        var sem = RemoteHomeserverSemaphore.GetOrCreate(cacheKey, _ => new SemaphoreSlim(1, 1));
-        await sem.WaitAsync();
-        RemoteHomeserver? hs;
-        lock (RemoteHomeserverCache) {
-            if (RemoteHomeserverCache.TryGetValue(cacheKey, out hs)) {
-                sem.Release();
-                return hs;
+            ServerVersionResponse? serverVersion;
+            try {
+                serverVersion = await (rhs.FederationClient?.GetServerVersionAsync() ?? Task.FromResult<ServerVersionResponse?>(null)!);
+            }
+            catch (Exception e) {
+                logger.LogWarning(e, "Failed to get server version for {homeserver}", homeserver);
+                throw;
             }
-        }
 
-        hs = await RemoteHomeserver.Create(homeserver, proxy);
+            AuthenticatedHomeserverGeneric hs;
+            try {
+                if (clientVersions.UnstableFeatures.TryGetValue("gay.rory.mxapiextensions.v0", out var a) && a)
+                    hs = new AuthenticatedHomeserverMxApiExtended(homeserver, wellKnownUris, ref proxy, accessToken);
+                else {
+                    if (serverVersion is { Server.Name: "Synapse" })
+                        hs = new AuthenticatedHomeserverSynapse(homeserver, wellKnownUris, ref proxy, accessToken);
+                    else
+                        hs = new AuthenticatedHomeserverGeneric(homeserver, wellKnownUris, ref proxy, accessToken);
+                }
+            }
+            catch (Exception e) {
+                logger.LogError(e, "Failed to create authenticated homeserver for {homeserver}", homeserver);
+                throw;
+            }
 
-        lock (RemoteHomeserverCache) {
-            RemoteHomeserverCache[cacheKey] = hs;
-        }
+            await hs.Initialise();
 
-        sem.Release();
+            if (impersonatedMxid is not null)
+                await hs.SetImpersonate(impersonatedMxid);
 
-        return hs;
+            return hs;
+        });
     }
 
+    public async Task<RemoteHomeserver> GetRemoteHomeserver(string homeserver, string? proxy = null) =>
+        await RemoteHomeserverCache.GetOrAdd($"{homeserver}{proxy}", async () => {
+            return new RemoteHomeserver(homeserver, await hsResolver.ResolveHomeserverFromWellKnown(homeserver), ref proxy);
+        });
+
     public async Task<LoginResponse> Login(string homeserver, string user, string password, string? proxy = null) {
         var hs = await GetRemoteHomeserver(homeserver, proxy);
         var payload = new LoginRequest {