diff options
Diffstat (limited to '')
-rw-r--r-- | LibMatrix/Services/HomeserverProviderService.cs | 126 | ||||
-rw-r--r-- | LibMatrix/Services/HomeserverResolverService.cs | 169 | ||||
-rw-r--r-- | LibMatrix/Services/ServiceInstaller.cs | 3 |
3 files changed, 163 insertions, 135 deletions
diff --git a/LibMatrix/Services/HomeserverProviderService.cs b/LibMatrix/Services/HomeserverProviderService.cs index 8e2e15b..3995a26 100644 --- a/LibMatrix/Services/HomeserverProviderService.cs +++ b/LibMatrix/Services/HomeserverProviderService.cs @@ -1,4 +1,5 @@ using System.Net.Http.Json; +using ArcaneLibs.Collections; using ArcaneLibs.Extensions; using LibMatrix.Homeservers; using LibMatrix.Responses; @@ -6,99 +7,62 @@ using Microsoft.Extensions.Logging; namespace LibMatrix.Services; -public class HomeserverProviderService(ILogger<HomeserverProviderService> logger) { - private static readonly Dictionary<string, SemaphoreSlim> AuthenticatedHomeserverSemaphore = new(); - private static readonly Dictionary<string, AuthenticatedHomeserverGeneric> AuthenticatedHomeserverCache = new(); - - private static readonly Dictionary<string, SemaphoreSlim> RemoteHomeserverSemaphore = new(); - private static readonly Dictionary<string, RemoteHomeserver> RemoteHomeserverCache = new(); +public class HomeserverProviderService(ILogger<HomeserverProviderService> logger, HomeserverResolverService hsResolver) { + private static SemaphoreCache<AuthenticatedHomeserverGeneric> AuthenticatedHomeserverCache = new(); + private static SemaphoreCache<RemoteHomeserver> RemoteHomeserverCache = new(); public async Task<AuthenticatedHomeserverGeneric> GetAuthenticatedWithToken(string homeserver, string accessToken, string? proxy = null, string? impersonatedMxid = null) { - var cacheKey = homeserver + accessToken + proxy + impersonatedMxid; - var sem = AuthenticatedHomeserverSemaphore.GetOrCreate(cacheKey, _ => new SemaphoreSlim(1, 1)); - await sem.WaitAsync(); - AuthenticatedHomeserverGeneric? hs; - lock (AuthenticatedHomeserverCache) { - if (AuthenticatedHomeserverCache.TryGetValue(cacheKey, out hs)) { - sem.Release(); - return hs; - } - } - - var rhs = await RemoteHomeserver.Create(homeserver, proxy); - ClientVersionsResponse clientVersions = new(); - try { - clientVersions = await rhs.GetClientVersionsAsync(); - } - catch (Exception e) { - logger.LogError(e, "Failed to get client versions for {homeserver}", homeserver); - } - - if (proxy is not null) - logger.LogInformation("Homeserver {homeserver} proxied via {proxy}...", homeserver, proxy); - logger.LogInformation("{homeserver}: {clientVersions}", homeserver, clientVersions.ToJson()); + return await AuthenticatedHomeserverCache.GetOrAdd($"{homeserver}{accessToken}{proxy}{impersonatedMxid}", async () => { + var wellKnownUris = await hsResolver.ResolveHomeserverFromWellKnown(homeserver); + var rhs = new RemoteHomeserver(homeserver, wellKnownUris, ref proxy); - ServerVersionResponse serverVersion; - try { - serverVersion = serverVersion = await (rhs.FederationClient?.GetServerVersionAsync() ?? Task.FromResult<ServerVersionResponse?>(null)!); - } - catch (Exception e) { - logger.LogWarning(e, "Failed to get server version for {homeserver}", homeserver); - sem.Release(); - throw; - } - - try { - if (clientVersions.UnstableFeatures.TryGetValue("gay.rory.mxapiextensions.v0", out var a) && a) - hs = await AuthenticatedHomeserverGeneric.Create<AuthenticatedHomeserverMxApiExtended>(homeserver, accessToken, proxy); - else { - if (serverVersion is { Server.Name: "Synapse" }) - hs = await AuthenticatedHomeserverGeneric.Create<AuthenticatedHomeserverSynapse>(homeserver, accessToken, proxy); - else - hs = await AuthenticatedHomeserverGeneric.Create<AuthenticatedHomeserverGeneric>(homeserver, accessToken, proxy); + ClientVersionsResponse? clientVersions = new(); + try { + clientVersions = await rhs.GetClientVersionsAsync(); + } + catch (Exception e) { + logger.LogError(e, "Failed to get client versions for {homeserver}", homeserver); } - } - catch (Exception e) { - logger.LogError(e, "Failed to create authenticated homeserver for {homeserver}", homeserver); - sem.Release(); - throw; - } - - if (impersonatedMxid is not null) - await hs.SetImpersonate(impersonatedMxid); - - lock (AuthenticatedHomeserverCache) { - AuthenticatedHomeserverCache[cacheKey] = hs; - } - - sem.Release(); - - return hs; - } - public async Task<RemoteHomeserver> GetRemoteHomeserver(string homeserver, string? proxy = null) { - var cacheKey = homeserver + proxy; - var sem = RemoteHomeserverSemaphore.GetOrCreate(cacheKey, _ => new SemaphoreSlim(1, 1)); - await sem.WaitAsync(); - RemoteHomeserver? hs; - lock (RemoteHomeserverCache) { - if (RemoteHomeserverCache.TryGetValue(cacheKey, out hs)) { - sem.Release(); - return hs; + ServerVersionResponse? serverVersion; + try { + serverVersion = await (rhs.FederationClient?.GetServerVersionAsync() ?? Task.FromResult<ServerVersionResponse?>(null)!); + } + catch (Exception e) { + logger.LogWarning(e, "Failed to get server version for {homeserver}", homeserver); + throw; } - } - hs = await RemoteHomeserver.Create(homeserver, proxy); + AuthenticatedHomeserverGeneric hs; + try { + if (clientVersions.UnstableFeatures.TryGetValue("gay.rory.mxapiextensions.v0", out var a) && a) + hs = new AuthenticatedHomeserverMxApiExtended(homeserver, wellKnownUris, ref proxy, accessToken); + else { + if (serverVersion is { Server.Name: "Synapse" }) + hs = new AuthenticatedHomeserverSynapse(homeserver, wellKnownUris, ref proxy, accessToken); + else + hs = new AuthenticatedHomeserverGeneric(homeserver, wellKnownUris, ref proxy, accessToken); + } + } + catch (Exception e) { + logger.LogError(e, "Failed to create authenticated homeserver for {homeserver}", homeserver); + throw; + } - lock (RemoteHomeserverCache) { - RemoteHomeserverCache[cacheKey] = hs; - } + await hs.Initialise(); - sem.Release(); + if (impersonatedMxid is not null) + await hs.SetImpersonate(impersonatedMxid); - return hs; + return hs; + }); } + public async Task<RemoteHomeserver> GetRemoteHomeserver(string homeserver, string? proxy = null) => + await RemoteHomeserverCache.GetOrAdd($"{homeserver}{proxy}", async () => { + return new RemoteHomeserver(homeserver, await hsResolver.ResolveHomeserverFromWellKnown(homeserver), ref proxy); + }); + public async Task<LoginResponse> Login(string homeserver, string user, string password, string? proxy = null) { var hs = await GetRemoteHomeserver(homeserver, proxy); var payload = new LoginRequest { diff --git a/LibMatrix/Services/HomeserverResolverService.cs b/LibMatrix/Services/HomeserverResolverService.cs index bcef541..42ad0a1 100644 --- a/LibMatrix/Services/HomeserverResolverService.cs +++ b/LibMatrix/Services/HomeserverResolverService.cs @@ -1,87 +1,135 @@ using System.Collections.Concurrent; +using System.Diagnostics; +using System.Net.Http.Json; using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using ArcaneLibs.Collections; using ArcaneLibs.Extensions; using LibMatrix.Extensions; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; namespace LibMatrix.Services; -public class HomeserverResolverService(ILogger<HomeserverResolverService>? logger = null) { +public class HomeserverResolverService { private readonly MatrixHttpClient _httpClient = new() { Timeout = TimeSpan.FromMilliseconds(10000) }; - private static readonly ConcurrentDictionary<string, WellKnownUris> WellKnownCache = new(); - private static readonly ConcurrentDictionary<string, SemaphoreSlim> WellKnownSemaphores = new(); + private static readonly SemaphoreCache<WellKnownUris> WellKnownCache = new(); - public async Task<WellKnownUris> ResolveHomeserverFromWellKnown(string homeserver) { - if (homeserver is null) throw new ArgumentNullException(nameof(homeserver)); - WellKnownSemaphores.TryAdd(homeserver, new SemaphoreSlim(1, 1)); - await WellKnownSemaphores[homeserver].WaitAsync(); - if (WellKnownCache.TryGetValue(homeserver, out var known)) { - WellKnownSemaphores[homeserver].Release(); - return known; + private readonly ILogger<HomeserverResolverService> _logger; + + public HomeserverResolverService(ILogger<HomeserverResolverService> logger) { + _logger = logger; + if (logger is NullLogger<HomeserverResolverService>) { + var stackFrame = new StackTrace(true).GetFrame(1); + Console.WriteLine( + $"WARN | Null logger provided to HomeserverResolverService!\n{stackFrame.GetMethod().DeclaringType} at {stackFrame.GetFileName()}:{stackFrame.GetFileLineNumber()}"); } + } + + private static SemaphoreSlim _wellKnownSemaphore = new(1, 1); + + public async Task<WellKnownUris> ResolveHomeserverFromWellKnown(string homeserver) { + ArgumentNullException.ThrowIfNull(homeserver); - logger?.LogInformation("Resolving homeserver: {}", homeserver); - var res = new WellKnownUris { - Client = await _tryResolveFromClientWellknown(homeserver), - Server = await _tryResolveFromServerWellknown(homeserver) - }; - WellKnownCache.TryAdd(homeserver, res); - WellKnownSemaphores[homeserver].Release(); - return res; + return await WellKnownCache.GetOrAdd(homeserver, async () => { + await _wellKnownSemaphore.WaitAsync(); + _logger.LogTrace($"Resolving homeserver well-knowns: {homeserver}"); + var client = _tryResolveClientEndpoint(homeserver); + + var res = new WellKnownUris(); + + // try { + res.Client = await client ?? throw new Exception("Could not resolve client URL."); + // } + // catch (Exception e) { + // _logger.LogError(e, "Error resolving client well-known for {hs}", homeserver); + // } + + var server = _tryResolveServerEndpoint(homeserver); + + // try { + res.Server = await server ?? throw new Exception("Could not resolve server URL."); + // } + // catch (Exception e) { + // _logger.LogError(e, "Error resolving server well-known for {hs}", homeserver); + // } + + _logger.LogInformation("Resolved well-knowns for {hs}: {json}", homeserver, res.ToJson(indent: false)); + _wellKnownSemaphore.Release(); + return res; + }); } - private async Task<string?> _tryResolveFromClientWellknown(string homeserver) { - if (!homeserver.StartsWith("http")) { - if (await _httpClient.CheckSuccessStatus($"https://{homeserver}/.well-known/matrix/client")) - homeserver = "https://" + homeserver; - else if (await _httpClient.CheckSuccessStatus($"http://{homeserver}/.well-known/matrix/client")) { - homeserver = "http://" + homeserver; - } + private async Task<string?> _tryResolveClientEndpoint(string homeserver) { + ArgumentNullException.ThrowIfNull(homeserver); + _logger.LogTrace("Resolving client well-known: {homeserver}", homeserver); + ClientWellKnown? clientWellKnown = null; + // check if homeserver has a client well-known + if (homeserver.StartsWith("https://")) { + clientWellKnown = await _httpClient.TryGetFromJsonAsync<ClientWellKnown>($"{homeserver}/.well-known/matrix/client"); } - - try { - var resp = await _httpClient.GetFromJsonAsync<JsonElement>($"{homeserver}/.well-known/matrix/client"); - var hs = resp.GetProperty("m.homeserver").GetProperty("base_url").GetString(); - return hs; + else if (homeserver.StartsWith("http://")) { + clientWellKnown = await _httpClient.TryGetFromJsonAsync<ClientWellKnown>($"{homeserver}/.well-known/matrix/client"); } - catch { - // ignored + else { + clientWellKnown ??= await _httpClient.TryGetFromJsonAsync<ClientWellKnown>($"https://{homeserver}/.well-known/matrix/client"); + clientWellKnown ??= await _httpClient.TryGetFromJsonAsync<ClientWellKnown>($"http://{homeserver}/.well-known/matrix/client"); + + if (clientWellKnown is null) { + if (await _httpClient.CheckSuccessStatus($"https://{homeserver}/_matrix/client/versions")) + return $"https://{homeserver}"; + if (await _httpClient.CheckSuccessStatus($"http://{homeserver}/_matrix/client/versions")) + return $"http://{homeserver}"; + } } - logger?.LogInformation("No client well-known..."); + if (!string.IsNullOrWhiteSpace(clientWellKnown?.Homeserver.BaseUrl)) + return clientWellKnown.Homeserver.BaseUrl; + + _logger.LogInformation("No client well-known..."); return null; } - private async Task<string?> _tryResolveFromServerWellknown(string homeserver) { - if (!homeserver.StartsWith("http")) { - if (await _httpClient.CheckSuccessStatus($"https://{homeserver}/.well-known/matrix/server")) - homeserver = "https://" + homeserver; - else if (await _httpClient.CheckSuccessStatus($"http://{homeserver}/.well-known/matrix/server")) { - homeserver = "http://" + homeserver; - } + private async Task<string?> _tryResolveServerEndpoint(string homeserver) { + // TODO: implement SRV delegation via DoH: https://developers.google.com/speed/public-dns/docs/doh/json + ArgumentNullException.ThrowIfNull(homeserver); + _logger.LogTrace($"Resolving server well-known: {homeserver}"); + ServerWellKnown? serverWellKnown = null; + // check if homeserver has a server well-known + if (homeserver.StartsWith("https://")) { + serverWellKnown = await _httpClient.TryGetFromJsonAsync<ServerWellKnown>($"{homeserver}/.well-known/matrix/server"); } - - try { - var resp = await _httpClient.GetFromJsonAsync<JsonElement>($"{homeserver}/.well-known/matrix/server"); - var hs = resp.GetProperty("m.server").GetString(); - if (hs is null) throw new InvalidDataException("m.server is null"); - if (!hs.StartsWithAnyOf("http://", "https://")) - hs = $"https://{hs}"; - return hs; + else if (homeserver.StartsWith("http://")) { + serverWellKnown = await _httpClient.TryGetFromJsonAsync<ServerWellKnown>($"{homeserver}/.well-known/matrix/server"); + } + else { + serverWellKnown ??= await _httpClient.TryGetFromJsonAsync<ServerWellKnown>($"https://{homeserver}/.well-known/matrix/server"); + serverWellKnown ??= await _httpClient.TryGetFromJsonAsync<ServerWellKnown>($"http://{homeserver}/.well-known/matrix/server"); } - catch { - // ignored + + _logger.LogInformation("Server well-known for {hs}: {json}", homeserver, serverWellKnown?.ToJson() ?? "null"); + + if (!string.IsNullOrWhiteSpace(serverWellKnown?.Homeserver)) { + var resolved = serverWellKnown.Homeserver; + if (resolved.StartsWith("https://") || resolved.StartsWith("http://")) + return resolved; + if (await _httpClient.CheckSuccessStatus($"https://{resolved}/_matrix/federation/v1/version")) + return $"https://{resolved}"; + if (await _httpClient.CheckSuccessStatus($"http://{resolved}/_matrix/federation/v1/version")) + return $"http://{resolved}"; + _logger.LogWarning("Server well-known points to invalid server: {resolved}", resolved); } - // fallback: most servers host these on the same location - var clientUrl = await _tryResolveFromClientWellknown(homeserver); + // fallback: most servers host C2S and S2S on the same domain + var clientUrl = await _tryResolveClientEndpoint(homeserver); if (clientUrl is not null && await _httpClient.CheckSuccessStatus($"{clientUrl}/_matrix/federation/v1/version")) return clientUrl; - logger?.LogInformation("No server well-known..."); + _logger.LogInformation("No server well-known..."); return null; } @@ -97,4 +145,19 @@ public class HomeserverResolverService(ILogger<HomeserverResolverService>? logge public string? Client { get; set; } public string? Server { get; set; } } + + public class ClientWellKnown { + [JsonPropertyName("m.homeserver")] + public WellKnownHomeserver Homeserver { get; set; } + + public class WellKnownHomeserver { + [JsonPropertyName("base_url")] + public string BaseUrl { get; set; } + } + } + + public class ServerWellKnown { + [JsonPropertyName("m.server")] + public string Homeserver { get; set; } + } } \ No newline at end of file diff --git a/LibMatrix/Services/ServiceInstaller.cs b/LibMatrix/Services/ServiceInstaller.cs index 0f07b61..06ea9de 100644 --- a/LibMatrix/Services/ServiceInstaller.cs +++ b/LibMatrix/Services/ServiceInstaller.cs @@ -1,4 +1,5 @@ using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; namespace LibMatrix.Services; @@ -11,7 +12,7 @@ public static class ServiceInstaller { services.AddSingleton(config ?? new RoryLibMatrixConfiguration()); //Add services - services.AddSingleton<HomeserverResolverService>(); + services.AddSingleton<HomeserverResolverService>(sp => new HomeserverResolverService(sp.GetRequiredService<ILogger<HomeserverResolverService>>())); // if (services.First(x => x.ServiceType == typeof(TieredStorageService)).Lifetime == ServiceLifetime.Singleton) { services.AddSingleton<HomeserverProviderService>(); |