From 37b97d65c0a5262539a5de560e911048166b8bba Mon Sep 17 00:00:00 2001 From: "Emma [it/its]@Rory&" Date: Fri, 5 Apr 2024 18:58:32 +0200 Subject: Fix homeserver resolution, rewrite homeserver initialisation, HSE work --- LibMatrix/Services/HomeserverProviderService.cs | 126 +++++++++--------------- 1 file changed, 45 insertions(+), 81 deletions(-) (limited to 'LibMatrix/Services/HomeserverProviderService.cs') diff --git a/LibMatrix/Services/HomeserverProviderService.cs b/LibMatrix/Services/HomeserverProviderService.cs index 8e2e15b..3995a26 100644 --- a/LibMatrix/Services/HomeserverProviderService.cs +++ b/LibMatrix/Services/HomeserverProviderService.cs @@ -1,4 +1,5 @@ using System.Net.Http.Json; +using ArcaneLibs.Collections; using ArcaneLibs.Extensions; using LibMatrix.Homeservers; using LibMatrix.Responses; @@ -6,99 +7,62 @@ using Microsoft.Extensions.Logging; namespace LibMatrix.Services; -public class HomeserverProviderService(ILogger logger) { - private static readonly Dictionary AuthenticatedHomeserverSemaphore = new(); - private static readonly Dictionary AuthenticatedHomeserverCache = new(); - - private static readonly Dictionary RemoteHomeserverSemaphore = new(); - private static readonly Dictionary RemoteHomeserverCache = new(); +public class HomeserverProviderService(ILogger logger, HomeserverResolverService hsResolver) { + private static SemaphoreCache AuthenticatedHomeserverCache = new(); + private static SemaphoreCache RemoteHomeserverCache = new(); public async Task GetAuthenticatedWithToken(string homeserver, string accessToken, string? proxy = null, string? impersonatedMxid = null) { - var cacheKey = homeserver + accessToken + proxy + impersonatedMxid; - var sem = AuthenticatedHomeserverSemaphore.GetOrCreate(cacheKey, _ => new SemaphoreSlim(1, 1)); - await sem.WaitAsync(); - AuthenticatedHomeserverGeneric? hs; - lock (AuthenticatedHomeserverCache) { - if (AuthenticatedHomeserverCache.TryGetValue(cacheKey, out hs)) { - sem.Release(); - return hs; - } - } - - var rhs = await RemoteHomeserver.Create(homeserver, proxy); - ClientVersionsResponse clientVersions = new(); - try { - clientVersions = await rhs.GetClientVersionsAsync(); - } - catch (Exception e) { - logger.LogError(e, "Failed to get client versions for {homeserver}", homeserver); - } - - if (proxy is not null) - logger.LogInformation("Homeserver {homeserver} proxied via {proxy}...", homeserver, proxy); - logger.LogInformation("{homeserver}: {clientVersions}", homeserver, clientVersions.ToJson()); + return await AuthenticatedHomeserverCache.GetOrAdd($"{homeserver}{accessToken}{proxy}{impersonatedMxid}", async () => { + var wellKnownUris = await hsResolver.ResolveHomeserverFromWellKnown(homeserver); + var rhs = new RemoteHomeserver(homeserver, wellKnownUris, ref proxy); - ServerVersionResponse serverVersion; - try { - serverVersion = serverVersion = await (rhs.FederationClient?.GetServerVersionAsync() ?? Task.FromResult(null)!); - } - catch (Exception e) { - logger.LogWarning(e, "Failed to get server version for {homeserver}", homeserver); - sem.Release(); - throw; - } - - try { - if (clientVersions.UnstableFeatures.TryGetValue("gay.rory.mxapiextensions.v0", out var a) && a) - hs = await AuthenticatedHomeserverGeneric.Create(homeserver, accessToken, proxy); - else { - if (serverVersion is { Server.Name: "Synapse" }) - hs = await AuthenticatedHomeserverGeneric.Create(homeserver, accessToken, proxy); - else - hs = await AuthenticatedHomeserverGeneric.Create(homeserver, accessToken, proxy); + ClientVersionsResponse? clientVersions = new(); + try { + clientVersions = await rhs.GetClientVersionsAsync(); + } + catch (Exception e) { + logger.LogError(e, "Failed to get client versions for {homeserver}", homeserver); } - } - catch (Exception e) { - logger.LogError(e, "Failed to create authenticated homeserver for {homeserver}", homeserver); - sem.Release(); - throw; - } - - if (impersonatedMxid is not null) - await hs.SetImpersonate(impersonatedMxid); - - lock (AuthenticatedHomeserverCache) { - AuthenticatedHomeserverCache[cacheKey] = hs; - } - - sem.Release(); - - return hs; - } - public async Task GetRemoteHomeserver(string homeserver, string? proxy = null) { - var cacheKey = homeserver + proxy; - var sem = RemoteHomeserverSemaphore.GetOrCreate(cacheKey, _ => new SemaphoreSlim(1, 1)); - await sem.WaitAsync(); - RemoteHomeserver? hs; - lock (RemoteHomeserverCache) { - if (RemoteHomeserverCache.TryGetValue(cacheKey, out hs)) { - sem.Release(); - return hs; + ServerVersionResponse? serverVersion; + try { + serverVersion = await (rhs.FederationClient?.GetServerVersionAsync() ?? Task.FromResult(null)!); + } + catch (Exception e) { + logger.LogWarning(e, "Failed to get server version for {homeserver}", homeserver); + throw; } - } - hs = await RemoteHomeserver.Create(homeserver, proxy); + AuthenticatedHomeserverGeneric hs; + try { + if (clientVersions.UnstableFeatures.TryGetValue("gay.rory.mxapiextensions.v0", out var a) && a) + hs = new AuthenticatedHomeserverMxApiExtended(homeserver, wellKnownUris, ref proxy, accessToken); + else { + if (serverVersion is { Server.Name: "Synapse" }) + hs = new AuthenticatedHomeserverSynapse(homeserver, wellKnownUris, ref proxy, accessToken); + else + hs = new AuthenticatedHomeserverGeneric(homeserver, wellKnownUris, ref proxy, accessToken); + } + } + catch (Exception e) { + logger.LogError(e, "Failed to create authenticated homeserver for {homeserver}", homeserver); + throw; + } - lock (RemoteHomeserverCache) { - RemoteHomeserverCache[cacheKey] = hs; - } + await hs.Initialise(); - sem.Release(); + if (impersonatedMxid is not null) + await hs.SetImpersonate(impersonatedMxid); - return hs; + return hs; + }); } + public async Task GetRemoteHomeserver(string homeserver, string? proxy = null) => + await RemoteHomeserverCache.GetOrAdd($"{homeserver}{proxy}", async () => { + return new RemoteHomeserver(homeserver, await hsResolver.ResolveHomeserverFromWellKnown(homeserver), ref proxy); + }); + public async Task Login(string homeserver, string user, string password, string? proxy = null) { var hs = await GetRemoteHomeserver(homeserver, proxy); var payload = new LoginRequest { -- cgit 1.4.1