diff options
Diffstat (limited to 'ModAS.Server/Authentication/AuthMiddleware.cs')
-rw-r--r-- | ModAS.Server/Authentication/AuthMiddleware.cs | 48 |
1 files changed, 30 insertions, 18 deletions
diff --git a/ModAS.Server/Authentication/AuthMiddleware.cs b/ModAS.Server/Authentication/AuthMiddleware.cs index 8b7266f..11d4d39 100644 --- a/ModAS.Server/Authentication/AuthMiddleware.cs +++ b/ModAS.Server/Authentication/AuthMiddleware.cs @@ -1,3 +1,6 @@ +using System.Collections; +using System.Collections.Concurrent; +using System.Collections.Specialized; using System.Net.Http.Headers; using System.Text.Json; using LibMatrix; @@ -23,14 +26,16 @@ public class AuthMiddleware(RequestDelegate next, ILogger<AuthMiddleware> logger } var authAttribute = endpoint?.Metadata.GetMetadata<UserAuthAttribute>(); - if (authAttribute is not null) - logger.LogInformation($"{nameof(Route)} authorization: {authAttribute.ToJson()}"); - else if (string.IsNullOrWhiteSpace(accessToken)) { - // auth is optional if auth attribute isnt set - Console.WriteLine($"Allowing unauthenticated request, AuthAttribute is not set!"); - await next(context); - return; + if (authAttribute is null) { + if (string.IsNullOrWhiteSpace(accessToken)) { + // auth is optional if auth attribute isnt set + Console.WriteLine($"Allowing unauthenticated request, AuthAttribute is not set!"); + await next(context); + return; + } } + else + logger.LogInformation($"{nameof(Route)} authorization: {authAttribute.ToJson()}"); if (string.IsNullOrWhiteSpace(accessToken)) if (authAttribute is not null) { @@ -42,20 +47,27 @@ public class AuthMiddleware(RequestDelegate next, ILogger<AuthMiddleware> logger return; } + if (await ValidateAuth(authAttribute, context, accessToken)) + await next(context); + } + + private async Task<bool> ValidateAuth(UserAuthAttribute? authAttribute, HttpContext context, string? accessToken) { try { - switch (authAttribute.AuthType) { + switch (authAttribute?.AuthType) { + case null: case AuthType.User: - var authUser = await GetAuthUser(accessToken); + if (string.IsNullOrWhiteSpace(accessToken) && authAttribute is null) + return true; //we dont care in this case + var authUser = await GetAuthUser(accessToken!); context.Items.Add("AuthUser", authUser); - break; + return true; case AuthType.Server: if (asr.HomeserverToken != accessToken) throw new MatrixException() { ErrorCode = "M_UNAUTHORIZED", Error = "Invalid access token" }; - - break; + return true; default: throw new ArgumentOutOfRangeException(); } @@ -63,17 +75,17 @@ public class AuthMiddleware(RequestDelegate next, ILogger<AuthMiddleware> logger catch (MatrixException e) { context.Response.StatusCode = 401; await context.Response.WriteAsJsonAsync(e.GetAsObject()); - return; + return false; } - - await next(context); } + private static readonly Dictionary<string, AuthUser> AuthCache = new(); + private async Task<AuthUser> GetAuthUser(string accessToken) { - AuthenticatedHomeserverGeneric? homeserver; - homeserver = await hsProvider.GetAuthenticatedWithToken(config.ServerName, accessToken, config.HomeserverUrl); + if (AuthCache.TryGetValue(accessToken, out var authUser)) return authUser; + var homeserver = await hsProvider.GetAuthenticatedWithToken(config.ServerName, accessToken, config.HomeserverUrl); - return new AuthUser() { + return AuthCache[accessToken] = new AuthUser() { Homeserver = homeserver, AccessToken = accessToken, Roles = config.Roles.Where(r => r.Value.Contains(homeserver.WhoAmI.UserId)).Select(r => r.Key).ToList() |