summary refs log tree commit diff
path: root/ModAS.Server/Authentication/AuthMiddleware.cs
blob: 11d4d39f2eb056e06e2c6e344582dfe905f83b3b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Specialized;
using System.Net.Http.Headers;
using System.Text.Json;
using LibMatrix;
using LibMatrix.Homeservers;
using LibMatrix.Services;
using ModAS.Server.Attributes;
using MxApiExtensions.Services;

namespace ModAS.Server.Authentication;

/// <summary>
/// Request middleware that authenticates callers according to the <see cref="UserAuthAttribute"/>
/// on the matched endpoint. The access token is read from the <c>access_token</c> query parameter,
/// falling back to the parameter of the <c>Authorization</c> header. Endpoints without the attribute
/// accept anonymous requests, but a token that is supplied anyway is still validated.
/// </summary>
public class AuthMiddleware(RequestDelegate next, ILogger<AuthMiddleware> logger, ModASConfiguration config, HomeserverProviderService hsProvider, AppServiceRegistration asr) {
    /// <summary>
    /// Resolved users keyed by raw access token, shared by every request in the process.
    /// Requests are handled in parallel, so this must be a thread-safe collection.
    /// NOTE(review): entries are never evicted, so a token revoked on the homeserver keeps being
    /// accepted here until the process restarts, and the cache grows without bound — consider an
    /// expiring cache.
    /// </summary>
    private static readonly ConcurrentDictionary<string, AuthUser> AuthCache = new();

    /// <summary>
    /// Authenticates the request when required and invokes the next middleware on success.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <remarks>
    /// No matched endpoint, or no attribute and no token: passed through unauthenticated.
    /// Attribute present but no token: 401 with <c>M_UNAUTHORIZED</c>.
    /// Otherwise <see cref="ValidateAuth"/> decides whether the pipeline continues.
    /// </remarks>
    public async Task InvokeAsync(HttpContext context) {
        var accessToken = GetAccessToken(context);

        // Auth requirements come from the UserAuth attribute on the matched endpoint.
        var endpoint = context.GetEndpoint();
        if (endpoint is null) {
            logger.LogDebug("Ignoring authentication, endpoint is null!");
            await next(context);
            return;
        }

        var authAttribute = endpoint.Metadata.GetMetadata<UserAuthAttribute>();
        if (authAttribute is null) {
            // Auth is optional when the attribute isn't set; a supplied token is still validated below.
            if (string.IsNullOrWhiteSpace(accessToken)) {
                logger.LogDebug("Allowing unauthenticated request, AuthAttribute is not set!");
                await next(context);
                return;
            }
        }
        else {
            logger.LogInformation("{Path} authorization: {Authorization}", context.Request.Path, authAttribute.ToJson());
            if (string.IsNullOrWhiteSpace(accessToken)) {
                context.Response.StatusCode = 401;
                await context.Response.WriteAsJsonAsync(new MatrixException() {
                    ErrorCode = "M_UNAUTHORIZED",
                    Error = "Missing access token"
                }.GetAsObject());
                return;
            }
        }

        if (await ValidateAuth(authAttribute, context, accessToken))
            await next(context);
    }

    /// <summary>
    /// Extracts the caller's access token. The first <c>access_token</c> query value wins if it is
    /// non-blank; otherwise the parameter of the <c>Authorization</c> header is used. The header's
    /// scheme is not checked.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <returns>The token, or null/blank when none was supplied.</returns>
    private static string? GetAccessToken(HttpContext context) {
        context.Request.Query.TryGetValue("access_token", out var queryAccessToken);
        var token = queryAccessToken.FirstOrDefault();
        // An empty "?access_token=" yields "" rather than null; it must not hide a header token.
        if (string.IsNullOrWhiteSpace(token))
            token = context.Request.GetTypedHeaders().Get<AuthenticationHeaderValue>("Authorization")?.Parameter;
        return token;
    }

    /// <summary>
    /// Validates <paramref name="accessToken"/> against the auth type required by the endpoint.
    /// </summary>
    /// <param name="authAttribute">The endpoint's attribute, or null when auth is optional.</param>
    /// <param name="context">The current HTTP context. For user auth it receives the resolved user under the <c>"AuthUser"</c> item key.</param>
    /// <param name="accessToken">The caller-supplied token.</param>
    /// <returns>
    /// True if the request may proceed. False after a 401 response has been written because a
    /// <see cref="MatrixException"/> was raised during validation.
    /// </returns>
    /// <exception cref="ArgumentOutOfRangeException">The attribute specifies an unhandled auth type.</exception>
    private async Task<bool> ValidateAuth(UserAuthAttribute? authAttribute, HttpContext context, string? accessToken) {
        try {
            switch (authAttribute?.AuthType) {
                case null:
                case AuthType.User:
                    if (string.IsNullOrWhiteSpace(accessToken) && authAttribute is null)
                        return true; // auth optional and no token supplied: nothing to validate
                    // Indexer rather than Add, so an already-present item cannot throw.
                    context.Items["AuthUser"] = await GetAuthUser(accessToken!);
                    return true;
                case AuthType.Server:
                    if (!TokensEqual(asr.HomeserverToken, accessToken))
                        throw new MatrixException() {
                            ErrorCode = "M_UNAUTHORIZED",
                            Error = "Invalid access token"
                        };
                    return true;
                default:
                    throw new ArgumentOutOfRangeException(nameof(authAttribute), authAttribute?.AuthType, "Unsupported auth type.");
            }
        }
        catch (MatrixException e) {
            context.Response.StatusCode = 401;
            await context.Response.WriteAsJsonAsync(e.GetAsObject());
            return false;
        }
    }

    /// <summary>
    /// Compares two secrets without short-circuiting on the first differing byte, so response timing
    /// does not reveal how much of a guessed token was correct. Inputs of different lengths fail fast.
    /// A null on either side never matches.
    /// </summary>
    private static bool TokensEqual(string? expected, string? provided) {
        if (expected is null || provided is null) return false;
        return System.Security.Cryptography.CryptographicOperations.FixedTimeEquals(
            System.Text.Encoding.UTF8.GetBytes(expected),
            System.Text.Encoding.UTF8.GetBytes(provided));
    }

    /// <summary>
    /// Resolves the homeserver session and ModAS roles for <paramref name="accessToken"/>, memoized in
    /// <see cref="AuthCache"/>.
    /// </summary>
    /// <param name="accessToken">A non-blank access token.</param>
    /// <returns>The cached or freshly resolved user.</returns>
    /// <remarks>
    /// Roles are the keys of <c>config.Roles</c> whose value contains the user ID reported by
    /// <c>homeserver.WhoAmI</c>. Two concurrent first requests with the same token may both query the
    /// homeserver; <c>GetOrAdd</c> ensures both end up with the same cached instance.
    /// </remarks>
    private async Task<AuthUser> GetAuthUser(string accessToken) {
        if (AuthCache.TryGetValue(accessToken, out var cached)) return cached;

        var homeserver = await hsProvider.GetAuthenticatedWithToken(config.ServerName, accessToken, config.HomeserverUrl);
        var authUser = new AuthUser() {
            Homeserver = homeserver,
            AccessToken = accessToken,
            Roles = config.Roles.Where(r => r.Value.Contains(homeserver.WhoAmI.UserId)).Select(r => r.Key).ToList()
        };
        return AuthCache.GetOrAdd(accessToken, authUser);
    }
}