about summary refs log tree commit diff
path: root/MatrixAntiDmSpam.Core/InviteManager.cs
blob: a48cea28c41c4e074e92f82aca6fbfc89a82ddb2 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
using System.Diagnostics;
using System.Runtime.CompilerServices;
using ArcaneLibs.Attributes;
using ArcaneLibs.Extensions;
using LibMatrix;
using LibMatrix.EventTypes.Spec.State.Policy;
using LibMatrix.Helpers;
using LibMatrix.Homeservers;
using LibMatrix.RoomTypes;
using LibMatrix.Utilities.Bot.Interfaces;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;

namespace MatrixAntiDmSpam.Core;

/// <summary>
/// Hosted service that rejects room invites matching an <c>m.ban</c> / <c>m.takedown</c> policy.
/// Every newly received invite is checked against all policies in the <see cref="PolicyStore"/>,
/// and every newly added policy is checked against all outstanding invites.
/// </summary>
public class InviteManager(
    ILogger<InviteManager> logger,
    AntiDmSpamConfiguration config,
    RoomInviteHandler roomInviteHandler,
    PolicyStore policyStore,
    AuthenticatedHomeserverGeneric homeserver) : IHostedService {
    // Room that receives human-readable notices; null when config.LogRoom is blank.
    private readonly GenericRoom? _logRoom = string.IsNullOrWhiteSpace(config.LogRoom) ? null : homeserver.GetRoom(config.LogRoom);

    /// <summary>Callbacks awaited in registration order after an invite has been rejected.</summary>
    public List<Func<RoomInviteContext, StateEventResponse, Task>> OnInviteRejected { get; } = [];

    /// <summary>Callbacks awaited in registration order before an invite is rejected (and before the log room notice).</summary>
    public List<Func<RoomInviteContext, StateEventResponse, Task>> OnBeforeInviteRejected { get; } = [];

    /// <summary>Subscribes to newly received invites and newly added policies.</summary>
    public Task StartAsync(CancellationToken cancellationToken) {
        roomInviteHandler.OnInviteReceived.Add(CheckPoliciesAgainstInvite);
        policyStore.OnPolicyAdded.Add(CheckPolicyAgainstOutstandingInvites);
        return Task.CompletedTask;
    }

    /// <summary>Unsubscribes the handlers registered in <see cref="StartAsync"/>.</summary>
    public Task StopAsync(CancellationToken cancellationToken) {
        roomInviteHandler.OnInviteReceived.Remove(CheckPoliciesAgainstInvite);
        policyStore.OnPolicyAdded.Remove(CheckPolicyAgainstOutstandingInvites);
        return Task.CompletedTask;
    }

    /// <summary>
    /// Checks a single invite against every known policy and, if any ban/takedown policy matches,
    /// rejects the invite exactly once. The returned task completes after the rejection finished.
    /// </summary>
    private async Task CheckPoliciesAgainstInvite(RoomInviteContext invite) {
        logger.LogInformation("Checking policies against invite");
        var sw = Stopwatch.StartNew();

        // Only the first match is acted upon; matching is separated from rejecting so that
        // concurrently-running iterations cannot start several rejections for the same invite.
        var gate = new object();
        (StateEventResponse PolicyEvent, PolicyRuleEventContent Policy)? match = null;

        // Technically not required, but helps with scaling against millions of policies
        Parallel.ForEach(policyStore.AllPolicies.Values, (policyEvent, loopState, idx) => {
            if (loopState.ShouldExitCurrentIteration) return;

            var policy = MatchPolicy(invite, policyEvent);
            if (policy is null) return;

            lock (gate) {
                if (match is not null) return;
                match = (policyEvent, policy);
            }

            logger.LogInformation("Found matching policy after {} iterations ({})", idx, sw.Elapsed);
            // Stop (unlike Break) does not require the remaining lower-index iterations to run.
            loopState.Stop();
        });

        if (match is { } found)
            await LogAndRejectInvite(invite, found.PolicyEvent, found.Policy);
    }

    /// <summary>
    /// Checks a newly added policy against every outstanding invite and awaits all resulting rejections.
    /// </summary>
    private async Task CheckPolicyAgainstOutstandingInvites(StateEventResponse newEvent) {
        // Snapshot first: rejecting an invite presumably removes it from roomInviteHandler.Invites,
        // and the rejections below start while we would otherwise still be enumerating it.
        var tasks = roomInviteHandler.Invites
            .ToList()
            .Select(invite => CheckPolicyAgainstInvite(invite, newEvent))
            .OfType<Task>() // drops the nulls returned for non-matching policies
            .ToList();

        await Task.WhenAll(tasks);
    }

    /// <summary>
    /// Starts rejecting <paramref name="invite"/> if <paramref name="policyEvent"/> matches it.
    /// </summary>
    /// <returns>The rejection task, or <c>null</c> when the policy does not apply to the invite.</returns>
    /// <exception cref="InvalidOperationException">The event content is not a policy rule.</exception>
    private Task? CheckPolicyAgainstInvite(RoomInviteContext invite, StateEventResponse policyEvent) {
        var policy = MatchPolicy(invite, policyEvent);
        return policy is null ? null : LogAndRejectInvite(invite, policyEvent, policy);
    }

    /// <summary>
    /// Decides whether a policy event is an <c>m.ban</c>/<c>m.takedown</c> rule that matches the invite,
    /// without performing any rejection.
    /// </summary>
    /// <returns>The matching policy content, or <c>null</c> when it does not apply.</returns>
    /// <exception cref="InvalidOperationException">The event content is not a policy rule.</exception>
    private PolicyRuleEventContent? MatchPolicy(RoomInviteContext invite, StateEventResponse policyEvent) {
        var policy = policyEvent.TypedContent as PolicyRuleEventContent ?? throw new InvalidOperationException("Policy is null");
        if (policy.GetNormalizedRecommendation() is not ("m.ban" or "m.takedown")) return null;

        var policyMatches = false;
        switch (policy) {
            case UserPolicyRuleEventContent userPolicy:
                policyMatches = userPolicy.EntityMatches(invite.MemberEvent.Sender!);
                break;
            case ServerPolicyRuleEventContent serverPolicy:
                policyMatches = serverPolicy.EntityMatches(invite.MemberEvent.Sender!);
                break;
            case RoomPolicyRuleEventContent roomPolicy:
                policyMatches = roomPolicy.EntityMatches(invite.RoomId);
                break;
            default:
                // Best-effort notice; intentionally not awaited so matching stays synchronous.
                if (_logRoom is not null)
                    _ = _logRoom.SendMessageEventAsync(new MessageBuilder().WithColoredBody("#FF0000", "Unknown policy type " + policy.GetType().FullName).Build());
                break;
        }

        if (!policyMatches) {
            logger.LogTrace("[{}] Policy {} does not match invite to {} by {}: {}", homeserver.WhoAmI.UserId, policy.GetType().GetFriendlyName(), invite.RoomId,
                invite.MemberEvent.Sender, policy.ToJson(ignoreNull: true, indent: false));
            return null;
        }

        return policy;
    }

    /// <summary>
    /// Logs, notifies the log room (best-effort), rejects the invite, and runs the before/after callbacks.
    /// </summary>
    private async Task LogAndRejectInvite(RoomInviteContext invite, StateEventResponse policyEvent, PolicyRuleEventContent policy) {
        // Fall back to the room ID instead of throwing when the policy list is not (or no longer) configured,
        // so that a config mismatch never prevents the rejection itself.
        var policyRoomName = config.PolicyLists.FirstOrDefault(x => x.RoomId == policyEvent.RoomId)?.Name ?? policyEvent.RoomId;
        logger.LogWarning("[{}] Rejecting invite to {}, matching {} in {}: {}", homeserver.WhoAmI.UserId, invite.RoomId, policy.GetType().GetFriendlyName(),
            policyRoomName, policy.ToJson(ignoreNull: true));

        foreach (var callback in OnBeforeInviteRejected) {
            await callback(invite, policyEvent);
        }

        if (_logRoom is not null) {
            try {
                var roomName = await invite.TryGetRoomNameAsync();
                var logMessage = new MessageBuilder()
                    .WithColoredBody("#FF0000",
                        cb => cb.WithBody("Rejecting invite to ").WithMention(invite.RoomId, roomName)
                            .WithBody($", matching {policy.GetType().GetFriendlyName().ToLowerInvariant()} in {policyRoomName}.")
                            .WithNewline())
                    .WithCollapsibleSection("Policy JSON", cb => cb.WithCodeBlock(policy.ToJson(ignoreNull: true), "json"))
                    .Build();
                await _logRoom.SendMessageEventAsync(logMessage);
            }
            catch (Exception e) {
                // The notice is informational; failing to send it must not block the rejection.
                logger.LogError(e, "[{}] Failed to send invite rejection notice for {} to the log room", homeserver.WhoAmI.UserId, invite.RoomId);
            }
        }

        await roomInviteHandler.RejectInvite(invite);

        foreach (var callback in OnInviteRejected) {
            await callback(invite, policyEvent);
        }
    }
}