summary refs log tree commit diff
path: root/ModAS.Server/Authentication/AuthMiddleware.cs
diff options
context:
space:
mode:
Diffstat (limited to 'ModAS.Server/Authentication/AuthMiddleware.cs')
-rw-r--r--ModAS.Server/Authentication/AuthMiddleware.cs48
1 files changed, 30 insertions, 18 deletions
diff --git a/ModAS.Server/Authentication/AuthMiddleware.cs b/ModAS.Server/Authentication/AuthMiddleware.cs
index 8b7266f..11d4d39 100644
--- a/ModAS.Server/Authentication/AuthMiddleware.cs
+++ b/ModAS.Server/Authentication/AuthMiddleware.cs
@@ -1,3 +1,6 @@
+using System.Collections;
+using System.Collections.Concurrent;
+using System.Collections.Specialized;
 using System.Net.Http.Headers;
 using System.Text.Json;
 using LibMatrix;
@@ -23,14 +26,16 @@ public class AuthMiddleware(RequestDelegate next, ILogger<AuthMiddleware> logger
         }
 
         var authAttribute = endpoint?.Metadata.GetMetadata<UserAuthAttribute>();
-        if (authAttribute is not null)
-            logger.LogInformation($"{nameof(Route)} authorization: {authAttribute.ToJson()}");
-        else if (string.IsNullOrWhiteSpace(accessToken)) {
-            // auth is optional if auth attribute isnt set
-            Console.WriteLine($"Allowing unauthenticated request, AuthAttribute is not set!");
-            await next(context);
-            return;
+        if (authAttribute is null) {
+            if (string.IsNullOrWhiteSpace(accessToken)) {
+                // auth is optional if auth attribute isnt set
+                Console.WriteLine($"Allowing unauthenticated request, AuthAttribute is not set!");
+                await next(context);
+                return;
+            }
         }
+        else
+            logger.LogInformation($"{nameof(Route)} authorization: {authAttribute.ToJson()}");
 
         if (string.IsNullOrWhiteSpace(accessToken))
             if (authAttribute is not null) {
@@ -42,20 +47,27 @@ public class AuthMiddleware(RequestDelegate next, ILogger<AuthMiddleware> logger
                 return;
             }
 
+        if (await ValidateAuth(authAttribute, context, accessToken))
+            await next(context);
+    }
+
+    private async Task<bool> ValidateAuth(UserAuthAttribute? authAttribute, HttpContext context, string? accessToken) {
         try {
-            switch (authAttribute.AuthType) {
+            switch (authAttribute?.AuthType) {
+                case null:
                 case AuthType.User:
-                    var authUser = await GetAuthUser(accessToken);
+                    if (string.IsNullOrWhiteSpace(accessToken) && authAttribute is null)
+                        return true; //we dont care in this case
+                    var authUser = await GetAuthUser(accessToken!);
                     context.Items.Add("AuthUser", authUser);
-                    break;
+                    return true;
                 case AuthType.Server:
                     if (asr.HomeserverToken != accessToken)
                         throw new MatrixException() {
                             ErrorCode = "M_UNAUTHORIZED",
                             Error = "Invalid access token"
                         };
-
-                    break;
+                    return true;
                 default:
                     throw new ArgumentOutOfRangeException();
             }
@@ -63,17 +75,17 @@ public class AuthMiddleware(RequestDelegate next, ILogger<AuthMiddleware> logger
         catch (MatrixException e) {
             context.Response.StatusCode = 401;
             await context.Response.WriteAsJsonAsync(e.GetAsObject());
-            return;
+            return false;
         }
-
-        await next(context);
     }
 
+    private static readonly Dictionary<string, AuthUser> AuthCache = new();
+
     private async Task<AuthUser> GetAuthUser(string accessToken) {
-        AuthenticatedHomeserverGeneric? homeserver;
-        homeserver = await hsProvider.GetAuthenticatedWithToken(config.ServerName, accessToken, config.HomeserverUrl);
+        if (AuthCache.TryGetValue(accessToken, out var authUser)) return authUser;
+        var homeserver = await hsProvider.GetAuthenticatedWithToken(config.ServerName, accessToken, config.HomeserverUrl);
 
-        return new AuthUser() {
+        return AuthCache[accessToken] = new AuthUser() {
             Homeserver = homeserver,
             AccessToken = accessToken,
             Roles = config.Roles.Where(r => r.Value.Contains(homeserver.WhoAmI.UserId)).Select(r => r.Key).ToList()