diff options
Diffstat (limited to 'LibMatrix/Services/HomeserverProviderService.cs')
-rw-r--r-- | LibMatrix/Services/HomeserverProviderService.cs | 69 |
1 files changed, 36 insertions, 33 deletions
diff --git a/LibMatrix/Services/HomeserverProviderService.cs b/LibMatrix/Services/HomeserverProviderService.cs index 3995a26..c61ef73 100644 --- a/LibMatrix/Services/HomeserverProviderService.cs +++ b/LibMatrix/Services/HomeserverProviderService.cs @@ -11,43 +11,47 @@ public class HomeserverProviderService(ILogger<HomeserverProviderService> logger private static SemaphoreCache<AuthenticatedHomeserverGeneric> AuthenticatedHomeserverCache = new(); private static SemaphoreCache<RemoteHomeserver> RemoteHomeserverCache = new(); - public async Task<AuthenticatedHomeserverGeneric> GetAuthenticatedWithToken(string homeserver, string accessToken, string? proxy = null, string? impersonatedMxid = null) { + public async Task<AuthenticatedHomeserverGeneric> GetAuthenticatedWithToken(string homeserver, string accessToken, string? proxy = null, string? impersonatedMxid = null, + bool useGeneric = false) { return await AuthenticatedHomeserverCache.GetOrAdd($"{homeserver}{accessToken}{proxy}{impersonatedMxid}", async () => { var wellKnownUris = await hsResolver.ResolveHomeserverFromWellKnown(homeserver); var rhs = new RemoteHomeserver(homeserver, wellKnownUris, ref proxy); + + AuthenticatedHomeserverGeneric? hs = null; + if (!useGeneric) + { + ClientVersionsResponse? clientVersions = new(); + try { + clientVersions = await rhs.GetClientVersionsAsync(); + } + catch (Exception e) { + logger.LogError(e, "Failed to get client versions for {homeserver}", homeserver); + } - ClientVersionsResponse? clientVersions = new(); - try { - clientVersions = await rhs.GetClientVersionsAsync(); - } - catch (Exception e) { - logger.LogError(e, "Failed to get client versions for {homeserver}", homeserver); - } - - ServerVersionResponse? serverVersion; - try { - serverVersion = await (rhs.FederationClient?.GetServerVersionAsync() ?? Task.FromResult<ServerVersionResponse?>(null)!); - } - catch (Exception e) { - logger.LogWarning(e, "Failed to get server version for {homeserver}", homeserver); - throw; - } + ServerVersionResponse? serverVersion; + try { + serverVersion = await (rhs.FederationClient?.GetServerVersionAsync() ?? Task.FromResult<ServerVersionResponse?>(null)!); + } + catch (Exception e) { + logger.LogWarning(e, "Failed to get server version for {homeserver}", homeserver); + throw; + } - AuthenticatedHomeserverGeneric hs; - try { - if (clientVersions.UnstableFeatures.TryGetValue("gay.rory.mxapiextensions.v0", out var a) && a) - hs = new AuthenticatedHomeserverMxApiExtended(homeserver, wellKnownUris, ref proxy, accessToken); - else { - if (serverVersion is { Server.Name: "Synapse" }) - hs = new AuthenticatedHomeserverSynapse(homeserver, wellKnownUris, ref proxy, accessToken); - else - hs = new AuthenticatedHomeserverGeneric(homeserver, wellKnownUris, ref proxy, accessToken); + try { + if (clientVersions.UnstableFeatures.TryGetValue("gay.rory.mxapiextensions.v0", out var a) && a) + hs = new AuthenticatedHomeserverMxApiExtended(homeserver, wellKnownUris, ref proxy, accessToken); + else { + if (serverVersion is { Server.Name: "Synapse" }) + hs = new AuthenticatedHomeserverSynapse(homeserver, wellKnownUris, ref proxy, accessToken); + } + } + catch (Exception e) { + logger.LogError(e, "Failed to create authenticated homeserver for {homeserver}", homeserver); + throw; } } - catch (Exception e) { - logger.LogError(e, "Failed to create authenticated homeserver for {homeserver}", homeserver); - throw; - } + + hs ??= new AuthenticatedHomeserverGeneric(homeserver, wellKnownUris, ref proxy, accessToken); await hs.Initialise(); @@ -59,9 +63,8 @@ public class HomeserverProviderService(ILogger<HomeserverProviderService> logger } public async Task<RemoteHomeserver> GetRemoteHomeserver(string homeserver, string? proxy = null) => - await RemoteHomeserverCache.GetOrAdd($"{homeserver}{proxy}", async () => { - return new RemoteHomeserver(homeserver, await hsResolver.ResolveHomeserverFromWellKnown(homeserver), ref proxy); - }); + await RemoteHomeserverCache.GetOrAdd($"{homeserver}{proxy}", + async () => { return new RemoteHomeserver(homeserver, await hsResolver.ResolveHomeserverFromWellKnown(homeserver), ref proxy); }); public async Task<LoginResponse> Login(string homeserver, string user, string password, string? proxy = null) { var hs = await GetRemoteHomeserver(homeserver, proxy); |