about summary refs log tree commit diff
path: root/LibMatrix/Services/HomeserverProviderService.cs
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--LibMatrix/Services/HomeserverProviderService.cs69
1 files changed, 36 insertions, 33 deletions
diff --git a/LibMatrix/Services/HomeserverProviderService.cs b/LibMatrix/Services/HomeserverProviderService.cs
index 3995a26..c61ef73 100644
--- a/LibMatrix/Services/HomeserverProviderService.cs
+++ b/LibMatrix/Services/HomeserverProviderService.cs
@@ -11,43 +11,47 @@ public class HomeserverProviderService(ILogger<HomeserverProviderService> logger
     private static SemaphoreCache<AuthenticatedHomeserverGeneric> AuthenticatedHomeserverCache = new();
     private static SemaphoreCache<RemoteHomeserver> RemoteHomeserverCache = new();
 
-    public async Task<AuthenticatedHomeserverGeneric> GetAuthenticatedWithToken(string homeserver, string accessToken, string? proxy = null, string? impersonatedMxid = null) {
+    public async Task<AuthenticatedHomeserverGeneric> GetAuthenticatedWithToken(string homeserver, string accessToken, string? proxy = null, string? impersonatedMxid = null,
+        bool useGeneric = false) {
         return await AuthenticatedHomeserverCache.GetOrAdd($"{homeserver}{accessToken}{proxy}{impersonatedMxid}", async () => {
             var wellKnownUris = await hsResolver.ResolveHomeserverFromWellKnown(homeserver);
             var rhs = new RemoteHomeserver(homeserver, wellKnownUris, ref proxy);
+            
+            AuthenticatedHomeserverGeneric? hs = null;
+            if (!useGeneric)
+            {
+                ClientVersionsResponse? clientVersions = new();
+                try {
+                    clientVersions = await rhs.GetClientVersionsAsync();
+                }
+                catch (Exception e) {
+                    logger.LogError(e, "Failed to get client versions for {homeserver}", homeserver);
+                }
 
-            ClientVersionsResponse? clientVersions = new();
-            try {
-                clientVersions = await rhs.GetClientVersionsAsync();
-            }
-            catch (Exception e) {
-                logger.LogError(e, "Failed to get client versions for {homeserver}", homeserver);
-            }
-
-            ServerVersionResponse? serverVersion;
-            try {
-                serverVersion = await (rhs.FederationClient?.GetServerVersionAsync() ?? Task.FromResult<ServerVersionResponse?>(null)!);
-            }
-            catch (Exception e) {
-                logger.LogWarning(e, "Failed to get server version for {homeserver}", homeserver);
-                throw;
-            }
+                ServerVersionResponse? serverVersion;
+                try {
+                    serverVersion = await (rhs.FederationClient?.GetServerVersionAsync() ?? Task.FromResult<ServerVersionResponse?>(null)!);
+                }
+                catch (Exception e) {
+                    logger.LogWarning(e, "Failed to get server version for {homeserver}", homeserver);
+                    throw;
+                }
 
-            AuthenticatedHomeserverGeneric hs;
-            try {
-                if (clientVersions.UnstableFeatures.TryGetValue("gay.rory.mxapiextensions.v0", out var a) && a)
-                    hs = new AuthenticatedHomeserverMxApiExtended(homeserver, wellKnownUris, ref proxy, accessToken);
-                else {
-                    if (serverVersion is { Server.Name: "Synapse" })
-                        hs = new AuthenticatedHomeserverSynapse(homeserver, wellKnownUris, ref proxy, accessToken);
-                    else
-                        hs = new AuthenticatedHomeserverGeneric(homeserver, wellKnownUris, ref proxy, accessToken);
+                try {
+                    if (clientVersions.UnstableFeatures.TryGetValue("gay.rory.mxapiextensions.v0", out var a) && a)
+                        hs = new AuthenticatedHomeserverMxApiExtended(homeserver, wellKnownUris, ref proxy, accessToken);
+                    else {
+                        if (serverVersion is { Server.Name: "Synapse" })
+                            hs = new AuthenticatedHomeserverSynapse(homeserver, wellKnownUris, ref proxy, accessToken);
+                    }
+                }
+                catch (Exception e) {
+                    logger.LogError(e, "Failed to create authenticated homeserver for {homeserver}", homeserver);
+                    throw;
                 }
             }
-            catch (Exception e) {
-                logger.LogError(e, "Failed to create authenticated homeserver for {homeserver}", homeserver);
-                throw;
-            }
+            
+            hs ??= new AuthenticatedHomeserverGeneric(homeserver, wellKnownUris, ref proxy, accessToken);
 
             await hs.Initialise();
 
@@ -59,9 +63,8 @@ public class HomeserverProviderService(ILogger<HomeserverProviderService> logger
     }
 
     public async Task<RemoteHomeserver> GetRemoteHomeserver(string homeserver, string? proxy = null) =>
-        await RemoteHomeserverCache.GetOrAdd($"{homeserver}{proxy}", async () => {
-            return new RemoteHomeserver(homeserver, await hsResolver.ResolveHomeserverFromWellKnown(homeserver), ref proxy);
-        });
+        await RemoteHomeserverCache.GetOrAdd($"{homeserver}{proxy}",
+            async () => { return new RemoteHomeserver(homeserver, await hsResolver.ResolveHomeserverFromWellKnown(homeserver), ref proxy); });
 
     public async Task<LoginResponse> Login(string homeserver, string user, string password, string? proxy = null) {
         var hs = await GetRemoteHomeserver(homeserver, proxy);