1 files changed, 27 insertions, 10 deletions
diff --git a/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs b/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs
index dc8a8dc..e0f9db5 100644
--- a/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs
+++ b/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs
@@ -1,20 +1,37 @@
+using ArcaneLibs.Extensions;
using LibMatrix.Homeservers;
using LibMatrix.Services;
using MxApiExtensions.Classes.LibMatrix;
namespace MxApiExtensions.Services;
-public class AuthenticatedHomeserverProviderService {
- private readonly AuthenticationService _authenticationService;
- private readonly HomeserverProviderService _homeserverProviderService;
-
- public AuthenticatedHomeserverProviderService(AuthenticationService authenticationService, HomeserverProviderService homeserverProviderService) {
- _authenticationService = authenticationService;
- _homeserverProviderService = homeserverProviderService;
+public class AuthenticatedHomeserverProviderService(AuthenticationService authenticationService, HomeserverProviderService homeserverProviderService, IHttpContextAccessor request) {
+ public async Task<RemoteHomeserver?> TryGetRemoteHomeserver() {
+ try {
+ return await GetRemoteHomeserver();
+ }
+ catch {
+ return null;
+ }
+ }
+
+ public async Task<RemoteHomeserver> GetRemoteHomeserver() {
+ try {
+ return await GetHomeserver();
+ }
+ catch (MxApiMatrixException e) {
+ if (e is not { ErrorCode: "M_MISSING_TOKEN" }) throw;
+ if (!request.HttpContext!.Request.Headers.Keys.Any(x=>x.ToUpper() == "MXAE_UPSTREAM"))
+ throw new MxApiMatrixException() {
+ ErrorCode = "MXAE_MISSING_UPSTREAM",
+ Error = "[MxApiExtensions] Missing MXAE_UPSTREAM header for unauthenticated request, this should be a server_name!"
+ };
+ return await homeserverProviderService.GetRemoteHomeserver(request.HttpContext.Request.Headers.GetByCaseInsensitiveKey("MXAE_UPSTREAM")[0]);
+ }
}
public async Task<AuthenticatedHomeserverGeneric> GetHomeserver() {
- var token = _authenticationService.GetToken();
+ var token = authenticationService.GetToken();
if (token == null) {
throw new MxApiMatrixException {
ErrorCode = "M_MISSING_TOKEN",
@@ -22,7 +39,7 @@ public class AuthenticatedHomeserverProviderService {
};
}
- var mxid = await _authenticationService.GetMxidFromToken(token);
+ var mxid = await authenticationService.GetMxidFromToken(token);
if (mxid == "@anonymous:*") {
throw new MxApiMatrixException {
ErrorCode = "M_MISSING_TOKEN",
@@ -31,6 +48,6 @@ public class AuthenticatedHomeserverProviderService {
}
var hsCanonical = string.Join(":", mxid.Split(':').Skip(1));
- return await _homeserverProviderService.GetAuthenticatedWithToken(hsCanonical, token);
+ return await homeserverProviderService.GetAuthenticatedWithToken(hsCanonical, token);
}
}
|