about summary refs log tree commit diff
path: root/ExampleBots/ModerationBot/PolicyEngine.cs
diff options
context:
space:
mode:
Diffstat (limited to 'ExampleBots/ModerationBot/PolicyEngine.cs')
-rw-r--r--ExampleBots/ModerationBot/PolicyEngine.cs268
1 files changed, 0 insertions, 268 deletions
diff --git a/ExampleBots/ModerationBot/PolicyEngine.cs b/ExampleBots/ModerationBot/PolicyEngine.cs
deleted file mode 100644
index 0d0ed65..0000000
--- a/ExampleBots/ModerationBot/PolicyEngine.cs
+++ /dev/null
@@ -1,268 +0,0 @@
-using System.Diagnostics;
-using System.Security.Cryptography;
-using System.Text.Json;
-using System.Text.RegularExpressions;
-using ArcaneLibs.Extensions;
-using LibMatrix;
-using LibMatrix.EventTypes.Spec;
-using LibMatrix.EventTypes.Spec.State.Policy;
-using LibMatrix.Helpers;
-using LibMatrix.Homeservers;
-using LibMatrix.RoomTypes;
-using LibMatrix.Services;
-using Microsoft.Extensions.Logging;
-using ModerationBot.AccountData;
-using ModerationBot.StateEventTypes.Policies;
-using ModerationBot.StateEventTypes.Policies.Implementations;
-
-namespace ModerationBot;
-
-public class PolicyEngine(AuthenticatedHomeserverGeneric hs, ILogger<ModerationBot> logger, ModerationBotConfiguration configuration, HomeserverResolverService hsResolver) {
-    private Dictionary<string, PolicyList> PolicyListAccountData { get; set; } = new();
-    // ReSharper disable once MemberCanBePrivate.Global
-    public List<PolicyList> ActivePolicyLists { get; set; } = new();
-    public List<BasePolicy> ActivePolicies { get; set; } = new();
-    public Dictionary<string, List<BasePolicy>> ActivePoliciesByType { get; set; } = new();
-    private GenericRoom? _logRoom;
-    private GenericRoom? _controlRoom;
-
-    public async Task ReloadActivePolicyLists() {
-        var sw = Stopwatch.StartNew();
-
-        var botData = await hs.GetAccountDataAsync<BotData>("gay.rory.moderation_bot_data");
-        _logRoom ??= hs.GetRoom(botData.LogRoom ?? botData.ControlRoom);
-        _controlRoom ??= hs.GetRoom(botData.ControlRoom);
-
-        await _controlRoom?.SendMessageEventAsync(MessageFormatter.FormatSuccess("Reloading policy lists!"))!;
-        await _logRoom?.SendMessageEventAsync(MessageFormatter.FormatSuccess("Reloading policy lists!"))!;
-
-        var progressMessage = await _logRoom?.SendMessageEventAsync(MessageFormatter.FormatSuccess("0/? policy lists loaded"))!;
-
-        var policyLists = new List<PolicyList>();
-        try {
-            PolicyListAccountData = await hs.GetAccountDataAsync<Dictionary<string, PolicyList>>("gay.rory.moderation_bot.policy_lists");
-        }
-        catch (MatrixException e) {
-            if (e is not { ErrorCode: "M_NOT_FOUND" }) throw;
-        }
-
-        if (!PolicyListAccountData.ContainsKey(botData.DefaultPolicyRoom!)) {
-            PolicyListAccountData.Add(botData.DefaultPolicyRoom!, new PolicyList() {
-                Trusted = true
-            });
-            await hs.SetAccountDataAsync("gay.rory.moderation_bot.policy_lists", PolicyListAccountData);
-        }
-
-        var loadTasks = new List<Task<PolicyList>>();
-        foreach (var (roomId, policyList) in PolicyListAccountData) {
-            var room = hs.GetRoom(roomId);
-            loadTasks.Add(LoadPolicyListAsync(room, policyList));
-        }
-
-        await foreach (var policyList in loadTasks.ToAsyncEnumerable()) {
-            policyLists.Add(policyList);
-
-            if (false || policyList.Policies.Count >= 256 || policyLists.Count == PolicyListAccountData.Count) {
-                var progressMsgContent = MessageFormatter.FormatSuccess($"{policyLists.Count}/{PolicyListAccountData.Count} policy lists loaded, " +
-                                                                        $"{policyLists.Sum(x => x.Policies.Count)} policies total, {sw.Elapsed} elapsed.")
-                    .SetReplaceRelation<RoomMessageEventContent>(progressMessage.EventId);
-
-                await _logRoom?.SendMessageEventAsync(progressMsgContent);
-            }
-        }
-
-        // Console.WriteLine($"Reloaded policy list data in {sw.Elapsed}");
-        // await _logRoom.SendMessageEventAsync(MessageFormatter.FormatSuccess($"Done fetching {policyLists.Count} policy lists in {sw.Elapsed}!"));
-
-        ActivePolicyLists = policyLists;
-        ActivePolicies = await GetActivePolicies();
-    }
-
-    private async Task<PolicyList> LoadPolicyListAsync(GenericRoom room, PolicyList policyList) {
-        policyList.Room = room;
-        policyList.Policies.Clear();
-
-        var stateEvents = room.GetFullStateAsync();
-        await foreach (var stateEvent in stateEvents) {
-            if (stateEvent != null && (
-                    stateEvent.MappedType.IsAssignableTo(typeof(BasePolicy))
-                    || stateEvent.MappedType.IsAssignableTo(typeof(PolicyRuleEventContent))
-                )) {
-                policyList.Policies.Add(stateEvent);
-            }
-        }
-
-        // if (policyList.Policies.Count >= 1)
-        // await _logRoom?.SendMessageEventAsync(
-        // MessageFormatter.FormatSuccess($"Loaded {policyList.Policies.Count} policies for {MessageFormatter.HtmlFormatMention(room.RoomId)}!"))!;
-
-        return policyList;
-    }
-
-
-    public async Task ReloadActivePolicyListById(string roomId) {
-        if (!ActivePolicyLists.Any(x => x.Room.RoomId == roomId)) return;
-        await LoadPolicyListAsync(hs.GetRoom(roomId), ActivePolicyLists.Single(x => x.Room.RoomId == roomId));
-        ActivePolicies = await GetActivePolicies();
-    }
-
-    public async Task<List<BasePolicy>> GetActivePolicies() {
-        var sw = Stopwatch.StartNew();
-        List<BasePolicy> activePolicies = new();
-
-        foreach (var activePolicyList in ActivePolicyLists) {
-            foreach (var policyEntry in activePolicyList.Policies) {
-                // TODO: implement rule translation
-                var policy = policyEntry.TypedContent is BasePolicy ? policyEntry.TypedContent as BasePolicy : policyEntry.RawContent.Deserialize<UnknownPolicy>();
-                if (policy?.Entity is null) continue;
-                policy.PolicyList = activePolicyList;
-                policy.OriginalEvent = policyEntry;
-                activePolicies.Add(policy);
-            }
-        }
-
-        Console.WriteLine($"Translated policy list data in {sw.Elapsed}");
-        ActivePoliciesByType = activePolicies.GroupBy(x => x.GetType().Name).ToDictionary(x => x.Key, x => x.ToList());
-        await _logRoom.SendMessageEventAsync(MessageFormatter.FormatSuccess($"Translated policy list data in {sw.GetElapsedAndRestart()}"));
-        // await _logRoom.SendMessageEventAsync(MessageFormatter.FormatSuccess($"Built policy type map in {sw.GetElapsedAndRestart()}"));
-
-        var summary = SummariseStateTypeCounts(activePolicies.Select(x => x.OriginalEvent).ToList());
-        await _logRoom?.SendMessageEventAsync(new RoomMessageEventContent() {
-            Body = summary.Raw,
-            FormattedBody = summary.Html,
-            Format = "org.matrix.custom.html"
-        })!;
-
-        return activePolicies;
-    }
-
-    public async Task<List<BasePolicy>> GetMatchingPolicies(StateEventResponse @event) {
-        List<BasePolicy> matchingPolicies = new();
-        if (@event.Sender == @hs.UserId) return matchingPolicies; //ignore self at all costs
-
-        if (ActivePoliciesByType.TryGetValue(nameof(ServerPolicyRuleEventContent), out var serverPolicies)) {
-            var userServer = @event.Sender.Split(':', 2)[1];
-            matchingPolicies.AddRange(serverPolicies.Where(x => x.Entity == userServer));
-        }
-
-        if (ActivePoliciesByType.TryGetValue(nameof(UserPolicyRuleEventContent), out var userPolicies)) {
-            matchingPolicies.AddRange(userPolicies.Where(x => x.Entity == @event.Sender));
-        }
-
-        if (@event.TypedContent is RoomMessageEventContent msgContent) {
-            matchingPolicies.AddRange(await CheckMessageContent(@event));
-            // if (msgContent.MessageType == "m.text" || msgContent.MessageType == "m.notice") ; //TODO: implement word etc. filters
-            if (msgContent.MessageType == "m.image" || msgContent.MessageType == "m.file" || msgContent.MessageType == "m.audio" || msgContent.MessageType == "m.video")
-                matchingPolicies.AddRange(await CheckMedia(@event));
-        }
-
-        return matchingPolicies;
-    }
-
-    #region Policy matching
-
-    private async Task<List<BasePolicy>> CheckMessageContent(StateEventResponse @event) {
-        var matchedRules = new List<BasePolicy>();
-        var msgContent = @event.TypedContent as RoomMessageEventContent;
-
-        if (ActivePoliciesByType.TryGetValue(nameof(MessagePolicyContainsText), out var messageContainsPolicies))
-            foreach (var policy in messageContainsPolicies) {
-                if ((@msgContent?.Body?.ToLowerInvariant().Contains(policy.Entity.ToLowerInvariant()) ?? false) || (@msgContent?.FormattedBody?.ToLowerInvariant().Contains(policy.Entity.ToLowerInvariant()) ?? false))
-                    matchedRules.Add(policy);
-            }
-
-
-        return matchedRules;
-    }
-
-    private async Task<List<BasePolicy>> CheckMedia(StateEventResponse @event) {
-        var matchedRules = new List<BasePolicy>();
-        var hashAlgo = SHA3_256.Create();
-
-        var mxcUri = @event.RawContent["url"].GetValue<string>();
-
-        //check server policies before bothering with hashes
-        if (ActivePoliciesByType.TryGetValue(nameof(MediaPolicyHomeserver), out var mediaHomeserverPolicies))
-            foreach (var policy in mediaHomeserverPolicies) {
-                logger.LogInformation("Checking rule {rule}: {data}", policy.OriginalEvent.StateKey, policy.OriginalEvent.TypedContent.ToJson(ignoreNull: true, indent: false));
-                policy.Entity = policy.Entity.Replace("\\*", ".*").Replace("\\?", ".");
-                var regex = new Regex($"mxc://({policy.Entity})/.*", RegexOptions.Compiled | RegexOptions.IgnoreCase);
-                if (regex.IsMatch(@event.RawContent["url"]!.GetValue<string>())) {
-                    logger.LogInformation("{url} matched rule {rule}", @event.RawContent["url"], policy.ToJson(ignoreNull: true));
-                    matchedRules.Add(policy);
-                    // continue;
-                }
-            }
-
-        var resolvedUri = await hsResolver.ResolveMediaUri(mxcUri.Split('/')[2], mxcUri);
-        var uriHash = hashAlgo.ComputeHash(mxcUri.AsBytes().ToArray());
-        byte[]? fileHash = null;
-
-        try {
-            fileHash = await hashAlgo.ComputeHashAsync(await hs.ClientHttpClient.GetStreamAsync(resolvedUri));
-        }
-        catch (Exception ex) {
-            await _logRoom.SendMessageEventAsync(
-                MessageFormatter.FormatException($"Error calculating file hash for {mxcUri} via {mxcUri.Split('/')[2]} ({resolvedUri}), retrying via {hs.BaseUrl}...",
-                    ex));
-            try {
-                resolvedUri = await hsResolver.ResolveMediaUri(hs.BaseUrl, mxcUri);
-                fileHash = await hashAlgo.ComputeHashAsync(await hs.ClientHttpClient.GetStreamAsync(resolvedUri));
-            }
-            catch (Exception ex2) {
-                await _logRoom.SendMessageEventAsync(
-                    MessageFormatter.FormatException($"Error calculating file hash via {hs.BaseUrl} ({resolvedUri})!", ex2));
-            }
-        }
-
-        logger.LogInformation("Checking media {url} with hash {hash}", resolvedUri, fileHash);
-
-        if (ActivePoliciesByType.ContainsKey(nameof(MediaPolicyFile)))
-            foreach (MediaPolicyFile policy in ActivePoliciesByType[nameof(MediaPolicyFile)]) {
-                logger.LogInformation("Checking rule {rule}: {data}", policy.OriginalEvent.StateKey, policy.OriginalEvent.TypedContent.ToJson(ignoreNull: true, indent: false));
-                if (policy.Entity is not null && Convert.ToBase64String(uriHash).SequenceEqual(policy.Entity)) {
-                    logger.LogInformation("{url} matched rule {rule} by uri hash", @event.RawContent["url"], policy.ToJson(ignoreNull: true));
-                    matchedRules.Add(policy);
-                    // continue;
-                }
-                else logger.LogInformation("uri hash {uriHash} did not match rule's {ruleUriHash}", Convert.ToHexString(uriHash), policy.Entity);
-
-                if (policy.FileHash is not null && fileHash is not null && policy.FileHash == Convert.ToBase64String(fileHash)) {
-                    logger.LogInformation("{url} matched rule {rule} by file hash", @event.RawContent["url"], policy.ToJson(ignoreNull: true));
-                    matchedRules.Add(policy);
-                    // continue;
-                }
-                else logger.LogInformation("file hash {fileHash} did not match rule's {ruleFileHash}", Convert.ToBase64String(fileHash), policy.FileHash);
-
-                //check pixels every 10% of the way through the image using ImageSharp
-                // var image = Image.Load(await _hs._httpClient.GetStreamAsync(resolvedUri));
-            }
-        else logger.LogInformation("No active media file policies");
-        // logger.LogInformation("{url} did not match any rules", @event.RawContent["url"]);
-
-        return matchedRules;
-    }
-
-    #endregion
-
-    #region Internal code
-
-    #region Summarisation
-
-    private static (string Raw, string Html) SummariseStateTypeCounts(IList<StateEventResponse> states) {
-        string raw = "Count | State type | Mapped type", html = "<table><tr><th>Count</th><th>State type</th><th>Mapped type</th></tr>";
-        var groupedStates = states.GroupBy(x => x.Type).ToDictionary(x => x.Key, x => x.ToList()).OrderByDescending(x => x.Value.Count);
-        foreach (var (type, stateGroup) in groupedStates) {
-            raw += $"\n{stateGroup.Count} | {type} | {stateGroup[0].MappedType.Name}";
-            html += $"<tr><td>{stateGroup.Count}</td><td>{type}</td><td>{stateGroup[0].MappedType.Name}</td></tr>";
-        }
-
-        html += "</table>";
-        return (raw, html);
-    }
-
-    #endregion
-
-    #endregion
-
-}