diff options
author | Rory& <root@rory.gay> | 2024-01-31 18:50:16 +0100 |
---|---|---|
committer | Rory& <root@rory.gay> | 2024-01-31 18:50:16 +0100 |
commit | d133a168c17d605236e967ff63728420ddb0ed98 (patch) | |
tree | ac467c42591db6fb6b920682dd54da1f707f8c0c /ExampleBots/ModerationBot/PolicyEngine.cs | |
parent | Capabilities (diff) | |
download | LibMatrix-d133a168c17d605236e967ff63728420ddb0ed98.tar.xz |
Repo split
Diffstat (limited to 'ExampleBots/ModerationBot/PolicyEngine.cs')
-rw-r--r-- | ExampleBots/ModerationBot/PolicyEngine.cs | 268 |
1 files changed, 0 insertions, 268 deletions
diff --git a/ExampleBots/ModerationBot/PolicyEngine.cs b/ExampleBots/ModerationBot/PolicyEngine.cs deleted file mode 100644 index 0d0ed65..0000000 --- a/ExampleBots/ModerationBot/PolicyEngine.cs +++ /dev/null @@ -1,268 +0,0 @@ -using System.Diagnostics; -using System.Security.Cryptography; -using System.Text.Json; -using System.Text.RegularExpressions; -using ArcaneLibs.Extensions; -using LibMatrix; -using LibMatrix.EventTypes.Spec; -using LibMatrix.EventTypes.Spec.State.Policy; -using LibMatrix.Helpers; -using LibMatrix.Homeservers; -using LibMatrix.RoomTypes; -using LibMatrix.Services; -using Microsoft.Extensions.Logging; -using ModerationBot.AccountData; -using ModerationBot.StateEventTypes.Policies; -using ModerationBot.StateEventTypes.Policies.Implementations; - -namespace ModerationBot; - -public class PolicyEngine(AuthenticatedHomeserverGeneric hs, ILogger<ModerationBot> logger, ModerationBotConfiguration configuration, HomeserverResolverService hsResolver) { - private Dictionary<string, PolicyList> PolicyListAccountData { get; set; } = new(); - // ReSharper disable once MemberCanBePrivate.Global - public List<PolicyList> ActivePolicyLists { get; set; } = new(); - public List<BasePolicy> ActivePolicies { get; set; } = new(); - public Dictionary<string, List<BasePolicy>> ActivePoliciesByType { get; set; } = new(); - private GenericRoom? _logRoom; - private GenericRoom? _controlRoom; - - public async Task ReloadActivePolicyLists() { - var sw = Stopwatch.StartNew(); - - var botData = await hs.GetAccountDataAsync<BotData>("gay.rory.moderation_bot_data"); - _logRoom ??= hs.GetRoom(botData.LogRoom ?? botData.ControlRoom); - _controlRoom ??= hs.GetRoom(botData.ControlRoom); - - await _controlRoom?.SendMessageEventAsync(MessageFormatter.FormatSuccess("Reloading policy lists!"))!; - await _logRoom?.SendMessageEventAsync(MessageFormatter.FormatSuccess("Reloading policy lists!"))!; - - var progressMessage = await _logRoom?.SendMessageEventAsync(MessageFormatter.FormatSuccess("0/? policy lists loaded"))!; - - var policyLists = new List<PolicyList>(); - try { - PolicyListAccountData = await hs.GetAccountDataAsync<Dictionary<string, PolicyList>>("gay.rory.moderation_bot.policy_lists"); - } - catch (MatrixException e) { - if (e is not { ErrorCode: "M_NOT_FOUND" }) throw; - } - - if (!PolicyListAccountData.ContainsKey(botData.DefaultPolicyRoom!)) { - PolicyListAccountData.Add(botData.DefaultPolicyRoom!, new PolicyList() { - Trusted = true - }); - await hs.SetAccountDataAsync("gay.rory.moderation_bot.policy_lists", PolicyListAccountData); - } - - var loadTasks = new List<Task<PolicyList>>(); - foreach (var (roomId, policyList) in PolicyListAccountData) { - var room = hs.GetRoom(roomId); - loadTasks.Add(LoadPolicyListAsync(room, policyList)); - } - - await foreach (var policyList in loadTasks.ToAsyncEnumerable()) { - policyLists.Add(policyList); - - if (false || policyList.Policies.Count >= 256 || policyLists.Count == PolicyListAccountData.Count) { - var progressMsgContent = MessageFormatter.FormatSuccess($"{policyLists.Count}/{PolicyListAccountData.Count} policy lists loaded, " + - $"{policyLists.Sum(x => x.Policies.Count)} policies total, {sw.Elapsed} elapsed.") - .SetReplaceRelation<RoomMessageEventContent>(progressMessage.EventId); - - await _logRoom?.SendMessageEventAsync(progressMsgContent); - } - } - - // Console.WriteLine($"Reloaded policy list data in {sw.Elapsed}"); - // await _logRoom.SendMessageEventAsync(MessageFormatter.FormatSuccess($"Done fetching {policyLists.Count} policy lists in {sw.Elapsed}!")); - - ActivePolicyLists = policyLists; - ActivePolicies = await GetActivePolicies(); - } - - private async Task<PolicyList> LoadPolicyListAsync(GenericRoom room, PolicyList policyList) { - policyList.Room = room; - policyList.Policies.Clear(); - - var stateEvents = room.GetFullStateAsync(); - await foreach (var stateEvent in stateEvents) { - if (stateEvent != null && ( - stateEvent.MappedType.IsAssignableTo(typeof(BasePolicy)) - || stateEvent.MappedType.IsAssignableTo(typeof(PolicyRuleEventContent)) - )) { - policyList.Policies.Add(stateEvent); - } - } - - // if (policyList.Policies.Count >= 1) - // await _logRoom?.SendMessageEventAsync( - // MessageFormatter.FormatSuccess($"Loaded {policyList.Policies.Count} policies for {MessageFormatter.HtmlFormatMention(room.RoomId)}!"))!; - - return policyList; - } - - - public async Task ReloadActivePolicyListById(string roomId) { - if (!ActivePolicyLists.Any(x => x.Room.RoomId == roomId)) return; - await LoadPolicyListAsync(hs.GetRoom(roomId), ActivePolicyLists.Single(x => x.Room.RoomId == roomId)); - ActivePolicies = await GetActivePolicies(); - } - - public async Task<List<BasePolicy>> GetActivePolicies() { - var sw = Stopwatch.StartNew(); - List<BasePolicy> activePolicies = new(); - - foreach (var activePolicyList in ActivePolicyLists) { - foreach (var policyEntry in activePolicyList.Policies) { - // TODO: implement rule translation - var policy = policyEntry.TypedContent is BasePolicy ? policyEntry.TypedContent as BasePolicy : policyEntry.RawContent.Deserialize<UnknownPolicy>(); - if (policy?.Entity is null) continue; - policy.PolicyList = activePolicyList; - policy.OriginalEvent = policyEntry; - activePolicies.Add(policy); - } - } - - Console.WriteLine($"Translated policy list data in {sw.Elapsed}"); - ActivePoliciesByType = activePolicies.GroupBy(x => x.GetType().Name).ToDictionary(x => x.Key, x => x.ToList()); - await _logRoom.SendMessageEventAsync(MessageFormatter.FormatSuccess($"Translated policy list data in {sw.GetElapsedAndRestart()}")); - // await _logRoom.SendMessageEventAsync(MessageFormatter.FormatSuccess($"Built policy type map in {sw.GetElapsedAndRestart()}")); - - var summary = SummariseStateTypeCounts(activePolicies.Select(x => x.OriginalEvent).ToList()); - await _logRoom?.SendMessageEventAsync(new RoomMessageEventContent() { - Body = summary.Raw, - FormattedBody = summary.Html, - Format = "org.matrix.custom.html" - })!; - - return activePolicies; - } - - public async Task<List<BasePolicy>> GetMatchingPolicies(StateEventResponse @event) { - List<BasePolicy> matchingPolicies = new(); - if (@event.Sender == @hs.UserId) return matchingPolicies; //ignore self at all costs - - if (ActivePoliciesByType.TryGetValue(nameof(ServerPolicyRuleEventContent), out var serverPolicies)) { - var userServer = @event.Sender.Split(':', 2)[1]; - matchingPolicies.AddRange(serverPolicies.Where(x => x.Entity == userServer)); - } - - if (ActivePoliciesByType.TryGetValue(nameof(UserPolicyRuleEventContent), out var userPolicies)) { - matchingPolicies.AddRange(userPolicies.Where(x => x.Entity == @event.Sender)); - } - - if (@event.TypedContent is RoomMessageEventContent msgContent) { - matchingPolicies.AddRange(await CheckMessageContent(@event)); - // if (msgContent.MessageType == "m.text" || msgContent.MessageType == "m.notice") ; //TODO: implement word etc. filters - if (msgContent.MessageType == "m.image" || msgContent.MessageType == "m.file" || msgContent.MessageType == "m.audio" || msgContent.MessageType == "m.video") - matchingPolicies.AddRange(await CheckMedia(@event)); - } - - return matchingPolicies; - } - - #region Policy matching - - private async Task<List<BasePolicy>> CheckMessageContent(StateEventResponse @event) { - var matchedRules = new List<BasePolicy>(); - var msgContent = @event.TypedContent as RoomMessageEventContent; - - if (ActivePoliciesByType.TryGetValue(nameof(MessagePolicyContainsText), out var messageContainsPolicies)) - foreach (var policy in messageContainsPolicies) { - if ((@msgContent?.Body?.ToLowerInvariant().Contains(policy.Entity.ToLowerInvariant()) ?? false) || (@msgContent?.FormattedBody?.ToLowerInvariant().Contains(policy.Entity.ToLowerInvariant()) ?? false)) - matchedRules.Add(policy); - } - - - return matchedRules; - } - - private async Task<List<BasePolicy>> CheckMedia(StateEventResponse @event) { - var matchedRules = new List<BasePolicy>(); - var hashAlgo = SHA3_256.Create(); - - var mxcUri = @event.RawContent["url"].GetValue<string>(); - - //check server policies before bothering with hashes - if (ActivePoliciesByType.TryGetValue(nameof(MediaPolicyHomeserver), out var mediaHomeserverPolicies)) - foreach (var policy in mediaHomeserverPolicies) { - logger.LogInformation("Checking rule {rule}: {data}", policy.OriginalEvent.StateKey, policy.OriginalEvent.TypedContent.ToJson(ignoreNull: true, indent: false)); - policy.Entity = policy.Entity.Replace("\\*", ".*").Replace("\\?", "."); - var regex = new Regex($"mxc://({policy.Entity})/.*", RegexOptions.Compiled | RegexOptions.IgnoreCase); - if (regex.IsMatch(@event.RawContent["url"]!.GetValue<string>())) { - logger.LogInformation("{url} matched rule {rule}", @event.RawContent["url"], policy.ToJson(ignoreNull: true)); - matchedRules.Add(policy); - // continue; - } - } - - var resolvedUri = await hsResolver.ResolveMediaUri(mxcUri.Split('/')[2], mxcUri); - var uriHash = hashAlgo.ComputeHash(mxcUri.AsBytes().ToArray()); - byte[]? fileHash = null; - - try { - fileHash = await hashAlgo.ComputeHashAsync(await hs.ClientHttpClient.GetStreamAsync(resolvedUri)); - } - catch (Exception ex) { - await _logRoom.SendMessageEventAsync( - MessageFormatter.FormatException($"Error calculating file hash for {mxcUri} via {mxcUri.Split('/')[2]} ({resolvedUri}), retrying via {hs.BaseUrl}...", - ex)); - try { - resolvedUri = await hsResolver.ResolveMediaUri(hs.BaseUrl, mxcUri); - fileHash = await hashAlgo.ComputeHashAsync(await hs.ClientHttpClient.GetStreamAsync(resolvedUri)); - } - catch (Exception ex2) { - await _logRoom.SendMessageEventAsync( - MessageFormatter.FormatException($"Error calculating file hash via {hs.BaseUrl} ({resolvedUri})!", ex2)); - } - } - - logger.LogInformation("Checking media {url} with hash {hash}", resolvedUri, fileHash); - - if (ActivePoliciesByType.ContainsKey(nameof(MediaPolicyFile))) - foreach (MediaPolicyFile policy in ActivePoliciesByType[nameof(MediaPolicyFile)]) { - logger.LogInformation("Checking rule {rule}: {data}", policy.OriginalEvent.StateKey, policy.OriginalEvent.TypedContent.ToJson(ignoreNull: true, indent: false)); - if (policy.Entity is not null && Convert.ToBase64String(uriHash).SequenceEqual(policy.Entity)) { - logger.LogInformation("{url} matched rule {rule} by uri hash", @event.RawContent["url"], policy.ToJson(ignoreNull: true)); - matchedRules.Add(policy); - // continue; - } - else logger.LogInformation("uri hash {uriHash} did not match rule's {ruleUriHash}", Convert.ToHexString(uriHash), policy.Entity); - - if (policy.FileHash is not null && fileHash is not null && policy.FileHash == Convert.ToBase64String(fileHash)) { - logger.LogInformation("{url} matched rule {rule} by file hash", @event.RawContent["url"], policy.ToJson(ignoreNull: true)); - matchedRules.Add(policy); - // continue; - } - else logger.LogInformation("file hash {fileHash} did not match rule's {ruleFileHash}", Convert.ToBase64String(fileHash), policy.FileHash); - - //check pixels every 10% of the way through the image using ImageSharp - // var image = Image.Load(await _hs._httpClient.GetStreamAsync(resolvedUri)); - } - else logger.LogInformation("No active media file policies"); - // logger.LogInformation("{url} did not match any rules", @event.RawContent["url"]); - - return matchedRules; - } - - #endregion - - #region Internal code - - #region Summarisation - - private static (string Raw, string Html) SummariseStateTypeCounts(IList<StateEventResponse> states) { - string raw = "Count | State type | Mapped type", html = "<table><tr><th>Count</th><th>State type</th><th>Mapped type</th></tr>"; - var groupedStates = states.GroupBy(x => x.Type).ToDictionary(x => x.Key, x => x.ToList()).OrderByDescending(x => x.Value.Count); - foreach (var (type, stateGroup) in groupedStates) { - raw += $"\n{stateGroup.Count} | {type} | {stateGroup[0].MappedType.Name}"; - html += $"<tr><td>{stateGroup.Count}</td><td>{type}</td><td>{stateGroup[0].MappedType.Name}</td></tr>"; - } - - html += "</table>"; - return (raw, html); - } - - #endregion - - #endregion - -} |