diff --git a/ModAS.Server/Authentication/AuthMiddleware.cs b/ModAS.Server/Authentication/AuthMiddleware.cs
index 8b7266f..11d4d39 100644
--- a/ModAS.Server/Authentication/AuthMiddleware.cs
+++ b/ModAS.Server/Authentication/AuthMiddleware.cs
@@ -1,3 +1,6 @@
+using System.Collections;
+using System.Collections.Concurrent;
+using System.Collections.Specialized;
using System.Net.Http.Headers;
using System.Text.Json;
using LibMatrix;
@@ -23,14 +26,16 @@ public class AuthMiddleware(RequestDelegate next, ILogger<AuthMiddleware> logger
}
var authAttribute = endpoint?.Metadata.GetMetadata<UserAuthAttribute>();
- if (authAttribute is not null)
- logger.LogInformation($"{nameof(Route)} authorization: {authAttribute.ToJson()}");
- else if (string.IsNullOrWhiteSpace(accessToken)) {
- // auth is optional if auth attribute isnt set
- Console.WriteLine($"Allowing unauthenticated request, AuthAttribute is not set!");
- await next(context);
- return;
+ if (authAttribute is null) {
+ if (string.IsNullOrWhiteSpace(accessToken)) {
+ // auth is optional if auth attribute isnt set
+ Console.WriteLine($"Allowing unauthenticated request, AuthAttribute is not set!");
+ await next(context);
+ return;
+ }
}
+ else
+ logger.LogInformation($"{nameof(Route)} authorization: {authAttribute.ToJson()}");
if (string.IsNullOrWhiteSpace(accessToken))
if (authAttribute is not null) {
@@ -42,20 +47,27 @@ public class AuthMiddleware(RequestDelegate next, ILogger<AuthMiddleware> logger
return;
}
+ if (await ValidateAuth(authAttribute, context, accessToken))
+ await next(context);
+ }
+
+ private async Task<bool> ValidateAuth(UserAuthAttribute? authAttribute, HttpContext context, string? accessToken) {
try {
- switch (authAttribute.AuthType) {
+ switch (authAttribute?.AuthType) {
+ case null:
case AuthType.User:
- var authUser = await GetAuthUser(accessToken);
+ if (string.IsNullOrWhiteSpace(accessToken) && authAttribute is null)
+ return true; //we dont care in this case
+ var authUser = await GetAuthUser(accessToken!);
context.Items.Add("AuthUser", authUser);
- break;
+ return true;
case AuthType.Server:
if (asr.HomeserverToken != accessToken)
throw new MatrixException() {
ErrorCode = "M_UNAUTHORIZED",
Error = "Invalid access token"
};
-
- break;
+ return true;
default:
throw new ArgumentOutOfRangeException();
}
@@ -63,17 +75,17 @@ public class AuthMiddleware(RequestDelegate next, ILogger<AuthMiddleware> logger
catch (MatrixException e) {
context.Response.StatusCode = 401;
await context.Response.WriteAsJsonAsync(e.GetAsObject());
- return;
+ return false;
}
-
- await next(context);
}
+ private static readonly Dictionary<string, AuthUser> AuthCache = new();
+
private async Task<AuthUser> GetAuthUser(string accessToken) {
- AuthenticatedHomeserverGeneric? homeserver;
- homeserver = await hsProvider.GetAuthenticatedWithToken(config.ServerName, accessToken, config.HomeserverUrl);
+ if (AuthCache.TryGetValue(accessToken, out var authUser)) return authUser;
+ var homeserver = await hsProvider.GetAuthenticatedWithToken(config.ServerName, accessToken, config.HomeserverUrl);
- return new AuthUser() {
+ return AuthCache[accessToken] = new AuthUser() {
Homeserver = homeserver,
AccessToken = accessToken,
Roles = config.Roles.Where(r => r.Value.Contains(homeserver.WhoAmI.UserId)).Select(r => r.Key).ToList()
|