diff --git a/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs b/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs
index dc8a8dc..e0f9db5 100644
--- a/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs
+++ b/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs
@@ -1,20 +1,37 @@
+using ArcaneLibs.Extensions;
using LibMatrix.Homeservers;
using LibMatrix.Services;
using MxApiExtensions.Classes.LibMatrix;
namespace MxApiExtensions.Services;
-public class AuthenticatedHomeserverProviderService {
- private readonly AuthenticationService _authenticationService;
- private readonly HomeserverProviderService _homeserverProviderService;
-
- public AuthenticatedHomeserverProviderService(AuthenticationService authenticationService, HomeserverProviderService homeserverProviderService) {
- _authenticationService = authenticationService;
- _homeserverProviderService = homeserverProviderService;
+public class AuthenticatedHomeserverProviderService(AuthenticationService authenticationService, HomeserverProviderService homeserverProviderService, IHttpContextAccessor request) {
+ public async Task<RemoteHomeserver?> TryGetRemoteHomeserver() {
+ try {
+ return await GetRemoteHomeserver();
+ }
+ catch {
+ return null;
+ }
+ }
+
+ public async Task<RemoteHomeserver> GetRemoteHomeserver() {
+ try {
+ return await GetHomeserver();
+ }
+ catch (MxApiMatrixException e) {
+ if (e is not { ErrorCode: "M_MISSING_TOKEN" }) throw;
+ if (!request.HttpContext!.Request.Headers.Keys.Any(x=>x.ToUpper() == "MXAE_UPSTREAM"))
+ throw new MxApiMatrixException() {
+ ErrorCode = "MXAE_MISSING_UPSTREAM",
+ Error = "[MxApiExtensions] Missing MXAE_UPSTREAM header for unauthenticated request, this should be a server_name!"
+ };
+ return await homeserverProviderService.GetRemoteHomeserver(request.HttpContext.Request.Headers.GetByCaseInsensitiveKey("MXAE_UPSTREAM")[0]);
+ }
}
public async Task<AuthenticatedHomeserverGeneric> GetHomeserver() {
- var token = _authenticationService.GetToken();
+ var token = authenticationService.GetToken();
if (token == null) {
throw new MxApiMatrixException {
ErrorCode = "M_MISSING_TOKEN",
@@ -22,7 +39,7 @@ public class AuthenticatedHomeserverProviderService {
};
}
- var mxid = await _authenticationService.GetMxidFromToken(token);
+ var mxid = await authenticationService.GetMxidFromToken(token);
if (mxid == "@anonymous:*") {
throw new MxApiMatrixException {
ErrorCode = "M_MISSING_TOKEN",
@@ -31,6 +48,6 @@ public class AuthenticatedHomeserverProviderService {
}
var hsCanonical = string.Join(":", mxid.Split(':').Skip(1));
- return await _homeserverProviderService.GetAuthenticatedWithToken(hsCanonical, token);
+ return await homeserverProviderService.GetAuthenticatedWithToken(hsCanonical, token);
}
}
diff --git a/MxApiExtensions/Services/AuthenticationService.cs b/MxApiExtensions/Services/AuthenticationService.cs
index 9eac20a..0dcc8b1 100644
--- a/MxApiExtensions/Services/AuthenticationService.cs
+++ b/MxApiExtensions/Services/AuthenticationService.cs
@@ -3,21 +3,11 @@ using MxApiExtensions.Classes.LibMatrix;
namespace MxApiExtensions.Services;
-public class AuthenticationService {
- private readonly ILogger<AuthenticationService> _logger;
- private readonly MxApiExtensionsConfiguration _config;
- private readonly HomeserverProviderService _homeserverProviderService;
- private readonly HttpRequest _request;
+public class AuthenticationService(ILogger<AuthenticationService> logger, MxApiExtensionsConfiguration config, IHttpContextAccessor request, HomeserverProviderService homeserverProviderService) {
+ private readonly HttpRequest _request = request.HttpContext!.Request;
private static Dictionary<string, string> _tokenMap = new();
- public AuthenticationService(ILogger<AuthenticationService> logger, MxApiExtensionsConfiguration config, IHttpContextAccessor request, HomeserverProviderService homeserverProviderService) {
- _logger = logger;
- _config = config;
- _homeserverProviderService = homeserverProviderService;
- _request = request.HttpContext!.Request;
- }
-
internal string? GetToken(bool fail = true) {
string? token;
if (_request.Headers.TryGetValue("Authorization", out var tokens)) {
@@ -59,7 +49,7 @@ public class AuthenticationService {
if (_tokenMap.TryGetValue(token, out var mxid)) return mxid;
var lookupTasks = new Dictionary<string, Task<string?>>();
- foreach (var homeserver in _config.AuthHomeservers) {
+ foreach (var homeserver in config.AuthHomeservers) {
lookupTasks.Add(homeserver, GetMxidFromToken(token, homeserver));
await lookupTasks[homeserver].WaitAsync(TimeSpan.FromMilliseconds(250));
if(lookupTasks[homeserver].IsCompletedSuccessfully && !string.IsNullOrWhiteSpace(lookupTasks[homeserver].Result)) break;
@@ -70,7 +60,7 @@ public class AuthenticationService {
if(mxid is null) {
throw new MxApiMatrixException {
ErrorCode = "M_UNKNOWN_TOKEN",
- Error = "Token not found on any configured homeservers: " + string.Join(", ", _config.AuthHomeservers)
+ Error = "Token not found on any configured homeservers: " + string.Join(", ", config.AuthHomeservers)
};
}
@@ -93,17 +83,17 @@ public class AuthenticationService {
//
// var json = (await JsonDocument.ParseAsync(await resp.Content.ReadAsStreamAsync())).RootElement;
// var mxid = json.GetProperty("user_id").GetString()!;
- _logger.LogInformation("Got mxid {} from token {}", mxid, token);
+ logger.LogInformation("Got mxid {} from token {}", mxid, token);
await SaveMxidForToken(token, mxid);
return mxid;
}
private async Task<string?> GetMxidFromToken(string token, string hsDomain) {
- _logger.LogInformation("Looking up mxid for token {} on {}", token, hsDomain);
- var hs = await _homeserverProviderService.GetAuthenticatedWithToken(hsDomain, token);
+ logger.LogInformation("Looking up mxid for token {} on {}", token, hsDomain);
+ var hs = await homeserverProviderService.GetAuthenticatedWithToken(hsDomain, token);
try {
var res = hs.WhoAmI.UserId;
- _logger.LogInformation("Got mxid {} for token {} on {}", res, token, hsDomain);
+ logger.LogInformation("Got mxid {} for token {} on {}", res, token, hsDomain);
return res;
}
catch (MxApiMatrixException e) {
|