From 2abb132234546e61bb0aff3897dc49e72ea84f5d Mon Sep 17 00:00:00 2001 From: TheArcaneBrony Date: Sun, 5 Nov 2023 17:59:38 +0100 Subject: Working sync proxy --- .../AuthenticatedHomeserverProviderService.cs | 37 ++++++++++++++++------ MxApiExtensions/Services/AuthenticationService.cs | 26 +++++---------- 2 files changed, 35 insertions(+), 28 deletions(-) (limited to 'MxApiExtensions/Services') diff --git a/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs b/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs index dc8a8dc..e0f9db5 100644 --- a/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs +++ b/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs @@ -1,20 +1,37 @@ +using ArcaneLibs.Extensions; using LibMatrix.Homeservers; using LibMatrix.Services; using MxApiExtensions.Classes.LibMatrix; namespace MxApiExtensions.Services; -public class AuthenticatedHomeserverProviderService { - private readonly AuthenticationService _authenticationService; - private readonly HomeserverProviderService _homeserverProviderService; - - public AuthenticatedHomeserverProviderService(AuthenticationService authenticationService, HomeserverProviderService homeserverProviderService) { - _authenticationService = authenticationService; - _homeserverProviderService = homeserverProviderService; +public class AuthenticatedHomeserverProviderService(AuthenticationService authenticationService, HomeserverProviderService homeserverProviderService, IHttpContextAccessor request) { + public async Task TryGetRemoteHomeserver() { + try { + return await GetRemoteHomeserver(); + } + catch { + return null; + } + } + + public async Task GetRemoteHomeserver() { + try { + return await GetHomeserver(); + } + catch (MxApiMatrixException e) { + if (e is not { ErrorCode: "M_MISSING_TOKEN" }) throw; + if (!request.HttpContext!.Request.Headers.Keys.Any(x=>x.ToUpper() == "MXAE_UPSTREAM")) + throw new MxApiMatrixException() { + ErrorCode = "MXAE_MISSING_UPSTREAM", + Error = "[MxApiExtensions] Missing MXAE_UPSTREAM header for unauthenticated request, this should be a server_name!" + }; + return await homeserverProviderService.GetRemoteHomeserver(request.HttpContext.Request.Headers.GetByCaseInsensitiveKey("MXAE_UPSTREAM")[0]); + } } public async Task GetHomeserver() { - var token = _authenticationService.GetToken(); + var token = authenticationService.GetToken(); if (token == null) { throw new MxApiMatrixException { ErrorCode = "M_MISSING_TOKEN", @@ -22,7 +39,7 @@ public class AuthenticatedHomeserverProviderService { }; } - var mxid = await _authenticationService.GetMxidFromToken(token); + var mxid = await authenticationService.GetMxidFromToken(token); if (mxid == "@anonymous:*") { throw new MxApiMatrixException { ErrorCode = "M_MISSING_TOKEN", @@ -31,6 +48,6 @@ public class AuthenticatedHomeserverProviderService { } var hsCanonical = string.Join(":", mxid.Split(':').Skip(1)); - return await _homeserverProviderService.GetAuthenticatedWithToken(hsCanonical, token); + return await homeserverProviderService.GetAuthenticatedWithToken(hsCanonical, token); } } diff --git a/MxApiExtensions/Services/AuthenticationService.cs b/MxApiExtensions/Services/AuthenticationService.cs index 9eac20a..0dcc8b1 100644 --- a/MxApiExtensions/Services/AuthenticationService.cs +++ b/MxApiExtensions/Services/AuthenticationService.cs @@ -3,21 +3,11 @@ using MxApiExtensions.Classes.LibMatrix; namespace MxApiExtensions.Services; -public class AuthenticationService { - private readonly ILogger _logger; - private readonly MxApiExtensionsConfiguration _config; - private readonly HomeserverProviderService _homeserverProviderService; - private readonly HttpRequest _request; +public class AuthenticationService(ILogger logger, MxApiExtensionsConfiguration config, IHttpContextAccessor request, HomeserverProviderService homeserverProviderService) { + private readonly HttpRequest _request = request.HttpContext!.Request; private static Dictionary _tokenMap = new(); - public AuthenticationService(ILogger logger, MxApiExtensionsConfiguration config, IHttpContextAccessor request, HomeserverProviderService homeserverProviderService) { - _logger = logger; - _config = config; - _homeserverProviderService = homeserverProviderService; - _request = request.HttpContext!.Request; - } - internal string? GetToken(bool fail = true) { string? token; if (_request.Headers.TryGetValue("Authorization", out var tokens)) { @@ -59,7 +49,7 @@ public class AuthenticationService { if (_tokenMap.TryGetValue(token, out var mxid)) return mxid; var lookupTasks = new Dictionary>(); - foreach (var homeserver in _config.AuthHomeservers) { + foreach (var homeserver in config.AuthHomeservers) { lookupTasks.Add(homeserver, GetMxidFromToken(token, homeserver)); await lookupTasks[homeserver].WaitAsync(TimeSpan.FromMilliseconds(250)); if(lookupTasks[homeserver].IsCompletedSuccessfully && !string.IsNullOrWhiteSpace(lookupTasks[homeserver].Result)) break; @@ -70,7 +60,7 @@ public class AuthenticationService { if(mxid is null) { throw new MxApiMatrixException { ErrorCode = "M_UNKNOWN_TOKEN", - Error = "Token not found on any configured homeservers: " + string.Join(", ", _config.AuthHomeservers) + Error = "Token not found on any configured homeservers: " + string.Join(", ", config.AuthHomeservers) }; } @@ -93,17 +83,17 @@ public class AuthenticationService { // // var json = (await JsonDocument.ParseAsync(await resp.Content.ReadAsStreamAsync())).RootElement; // var mxid = json.GetProperty("user_id").GetString()!; - _logger.LogInformation("Got mxid {} from token {}", mxid, token); + logger.LogInformation("Got mxid {} from token {}", mxid, token); await SaveMxidForToken(token, mxid); return mxid; } private async Task GetMxidFromToken(string token, string hsDomain) { - _logger.LogInformation("Looking up mxid for token {} on {}", token, hsDomain); - var hs = await _homeserverProviderService.GetAuthenticatedWithToken(hsDomain, token); + logger.LogInformation("Looking up mxid for token {} on {}", token, hsDomain); + var hs = await homeserverProviderService.GetAuthenticatedWithToken(hsDomain, token); try { var res = hs.WhoAmI.UserId; - _logger.LogInformation("Got mxid {} for token {} on {}", res, token, hsDomain); + logger.LogInformation("Got mxid {} for token {} on {}", res, token, hsDomain); return res; } catch (MxApiMatrixException e) { -- cgit 1.5.1