diff --git a/ExampleBots/ModerationBot/PolicyEngine.cs b/ExampleBots/ModerationBot/PolicyEngine.cs
new file mode 100644
index 0000000..5311637
--- /dev/null
+++ b/ExampleBots/ModerationBot/PolicyEngine.cs
@@ -0,0 +1,269 @@
+using System.Diagnostics;
+using System.Security.Cryptography;
+using System.Text.Json;
+using System.Text.RegularExpressions;
+using ArcaneLibs.Extensions;
+using LibMatrix;
+using LibMatrix.EventTypes.Spec;
+using LibMatrix.EventTypes.Spec.State;
+using LibMatrix.Helpers;
+using LibMatrix.Homeservers;
+using LibMatrix.Interfaces;
+using LibMatrix.RoomTypes;
+using LibMatrix.Services;
+using ModerationBot.AccountData;
+using ModerationBot.StateEventTypes;
+using Microsoft.Extensions.Logging;
+using ModerationBot.StateEventTypes.Policies;
+using ModerationBot.StateEventTypes.Policies.Implementations;
+
+namespace ModerationBot;
+
+public class PolicyEngine(AuthenticatedHomeserverGeneric hs, ILogger<ModerationBot> logger, ModerationBotConfiguration configuration, HomeserverResolverService hsResolver) {
+ private Dictionary<string, PolicyList> PolicyListAccountData { get; set; } = new();
+ public List<PolicyList> ActivePolicyLists { get; set; } = new();
+ public List<BasePolicy> ActivePolicies { get; set; } = new();
+ public Dictionary<string, List<BasePolicy>> ActivePoliciesByType { get; set; } = new();
+ private GenericRoom? _logRoom;
+ private GenericRoom? _controlRoom;
+
+ public async Task ReloadActivePolicyLists() {
+ var sw = Stopwatch.StartNew();
+
+ var botData = await hs.GetAccountDataAsync<BotData>("gay.rory.moderation_bot_data");
+ _logRoom ??= hs.GetRoom(botData.LogRoom ?? botData.ControlRoom);
+ _controlRoom ??= hs.GetRoom(botData.ControlRoom);
+
+ await _controlRoom?.SendMessageEventAsync(MessageFormatter.FormatSuccess("Reloading policy lists!"))!;
+ await _logRoom?.SendMessageEventAsync(MessageFormatter.FormatSuccess("Reloading policy lists!"))!;
+
+ var progressMessage = await _logRoom?.SendMessageEventAsync(MessageFormatter.FormatSuccess("0/? policy lists loaded"))!;
+
+ var policyLists = new List<PolicyList>();
+ try {
+ PolicyListAccountData = await hs.GetAccountDataAsync<Dictionary<string, PolicyList>>("gay.rory.moderation_bot.policy_lists");
+ }
+ catch (MatrixException e) {
+ if (e is not { ErrorCode: "M_NOT_FOUND" }) throw;
+ }
+
+ if (!PolicyListAccountData.ContainsKey(botData.DefaultPolicyRoom)) {
+ PolicyListAccountData.Add(botData.DefaultPolicyRoom, new PolicyList() {
+ Trusted = true
+ });
+ await hs.SetAccountDataAsync("gay.rory.moderation_bot.policy_lists", PolicyListAccountData);
+ }
+
+ var loadTasks = new List<Task<PolicyList>>();
+ foreach (var (roomId, policyList) in PolicyListAccountData) {
+ var room = hs.GetRoom(roomId);
+ loadTasks.Add(LoadPolicyListAsync(room, policyList));
+ }
+
+ await foreach (var policyList in loadTasks.ToAsyncEnumerable()) {
+ policyLists.Add(policyList);
+
+ if (policyList.Policies.Count >= 256 || policyLists.Count == PolicyListAccountData.Count) {
+ var progressMsgContent = MessageFormatter.FormatSuccess($"{policyLists.Count}/{PolicyListAccountData.Count} policy lists loaded, " +
+ $"{policyLists.Sum(x => x.Policies.Count)} policies total, {sw.Elapsed} elapsed.")
+ .SetReplaceRelation<RoomMessageEventContent>(progressMessage.EventId);
+
+ _logRoom?.SendMessageEventAsync(progressMsgContent);
+ }
+ }
+
+ // Console.WriteLine($"Reloaded policy list data in {sw.Elapsed}");
+ // await _logRoom.SendMessageEventAsync(MessageFormatter.FormatSuccess($"Done fetching {policyLists.Count} policy lists in {sw.Elapsed}!"));
+
+ ActivePolicyLists = policyLists;
+ ActivePolicies = await GetActivePolicies();
+ }
+
+ private async Task<PolicyList> LoadPolicyListAsync(GenericRoom room, PolicyList policyList) {
+ policyList.Room = room;
+ policyList.Policies.Clear();
+
+ var stateEvents = room.GetFullStateAsync();
+ await foreach (var stateEvent in stateEvents) {
+ if (stateEvent != null && (
+ stateEvent.GetType.IsAssignableTo(typeof(BasePolicy))
+ || stateEvent.GetType.IsAssignableTo(typeof(PolicyRuleEventContent))
+ )) {
+ policyList.Policies.Add(stateEvent);
+ }
+ }
+
+ // if (policyList.Policies.Count >= 1)
+ // await _logRoom?.SendMessageEventAsync(
+ // MessageFormatter.FormatSuccess($"Loaded {policyList.Policies.Count} policies for {MessageFormatter.HtmlFormatMention(room.RoomId)}!"))!;
+
+ return policyList;
+ }
+
+
+ public async Task ReloadActivePolicyListById(string roomId) {
+ if (!ActivePolicyLists.Any(x => x.Room.RoomId == roomId)) return;
+ await LoadPolicyListAsync(hs.GetRoom(roomId), ActivePolicyLists.Single(x => x.Room.RoomId == roomId));
+ ActivePolicies = await GetActivePolicies();
+ }
+
+ public async Task<List<BasePolicy>> GetActivePolicies() {
+ var sw = Stopwatch.StartNew();
+ List<BasePolicy> activePolicies = new();
+
+ foreach (var activePolicyList in ActivePolicyLists) {
+ foreach (var policyEntry in activePolicyList.Policies) {
+ // TODO: implement rule translation
+ BasePolicy policy = policyEntry.TypedContent is BasePolicy ? policyEntry.TypedContent as BasePolicy : policyEntry.RawContent.Deserialize<UnknownPolicy>();
+ if (policy.Entity is null) continue;
+ policy.PolicyList = activePolicyList;
+ policy.OriginalEvent = policyEntry;
+ activePolicies.Add(policy);
+ }
+ }
+
+ Console.WriteLine($"Translated policy list data in {sw.Elapsed}");
+ ActivePoliciesByType = activePolicies.GroupBy(x => x.GetType().Name).ToDictionary(x => x.Key, x => x.ToList());
+ await _logRoom.SendMessageEventAsync(MessageFormatter.FormatSuccess($"Translated policy list data in {sw.GetElapsedAndRestart()}"));
+ // await _logRoom.SendMessageEventAsync(MessageFormatter.FormatSuccess($"Built policy type map in {sw.GetElapsedAndRestart()}"));
+
+ var summary = SummariseStateTypeCounts(activePolicies.Select(x => x.OriginalEvent).ToList());
+ await _logRoom?.SendMessageEventAsync(new RoomMessageEventContent() {
+ Body = summary.Raw,
+ FormattedBody = summary.Html,
+ Format = "org.matrix.custom.html"
+ })!;
+
+ return activePolicies;
+ }
+
+ public async Task<List<BasePolicy>> GetMatchingPolicies(StateEventResponse @event) {
+ List<BasePolicy> matchingPolicies = new();
+ if (@event.Sender == @hs.UserId) return matchingPolicies; //ignore self at all costs
+
+ if (ActivePoliciesByType.TryGetValue(nameof(ServerPolicyRuleEventContent), out var serverPolicies)) {
+ var userServer = @event.Sender.Split(':', 2)[1];
+ matchingPolicies.AddRange(serverPolicies.Where(x => x.Entity == userServer));
+ }
+
+ if (ActivePoliciesByType.TryGetValue(nameof(UserPolicyRuleEventContent), out var userPolicies)) {
+ matchingPolicies.AddRange(userPolicies.Where(x => x.Entity == @event.Sender));
+ }
+
+ if (@event.TypedContent is RoomMessageEventContent msgContent) {
+ matchingPolicies.AddRange(await CheckMessageContent(@event));
+ if (msgContent.MessageType == "m.text" || msgContent.MessageType == "m.notice") ; //TODO: implement word etc. filters
+ if (msgContent.MessageType == "m.image" || msgContent.MessageType == "m.file" || msgContent.MessageType == "m.audio" || msgContent.MessageType == "m.video")
+ matchingPolicies.AddRange(await CheckMedia(@event));
+ }
+
+ return matchingPolicies;
+ }
+
+#region Policy matching
+
+ private async Task<List<BasePolicy>> CheckMessageContent(StateEventResponse @event) {
+ var matchedRules = new List<BasePolicy>();
+ var msgContent = @event.TypedContent as RoomMessageEventContent;
+
+ if (ActivePoliciesByType.TryGetValue(nameof(MessagePolicyContainsText), out var messageContainsPolicies))
+ foreach (var policy in messageContainsPolicies) {
+ if((@msgContent?.Body?.ToLowerInvariant().Contains(policy.Entity.ToLowerInvariant()) ?? false) || (@msgContent?.FormattedBody?.ToLowerInvariant().Contains(policy.Entity.ToLowerInvariant()) ?? false))
+ matchedRules.Add(policy);
+ }
+
+
+ return matchedRules;
+ }
+
+ private async Task<List<BasePolicy>> CheckMedia(StateEventResponse @event) {
+ var matchedRules = new List<BasePolicy>();
+ var hashAlgo = SHA3_256.Create();
+
+ var mxcUri = @event.RawContent["url"].GetValue<string>();
+
+ //check server policies before bothering with hashes
+ if (ActivePoliciesByType.TryGetValue(nameof(MediaPolicyHomeserver), out var mediaHomeserverPolicies))
+ foreach (var policy in mediaHomeserverPolicies) {
+ logger.LogInformation("Checking rule {rule}: {data}", policy.OriginalEvent.StateKey, policy.OriginalEvent.TypedContent.ToJson(ignoreNull: true, indent: false));
+ policy.Entity = policy.Entity.Replace("\\*", ".*").Replace("\\?", ".");
+ var regex = new Regex($"mxc://({policy.Entity})/.*", RegexOptions.Compiled | RegexOptions.IgnoreCase);
+ if (regex.IsMatch(@event.RawContent["url"]!.GetValue<string>())) {
+ logger.LogInformation("{url} matched rule {rule}", @event.RawContent["url"], policy.ToJson(ignoreNull: true));
+ matchedRules.Add(policy);
+ // continue;
+ }
+ }
+
+ var resolvedUri = await hsResolver.ResolveMediaUri(mxcUri.Split('/')[2], mxcUri);
+ var uriHash = hashAlgo.ComputeHash(mxcUri.AsBytes().ToArray());
+ byte[]? fileHash = null;
+
+ try {
+ fileHash = await hashAlgo.ComputeHashAsync(await hs.ClientHttpClient.GetStreamAsync(resolvedUri));
+ }
+ catch (Exception ex) {
+ await _logRoom.SendMessageEventAsync(
+ MessageFormatter.FormatException($"Error calculating file hash for {mxcUri} via {mxcUri.Split('/')[2]} ({resolvedUri}), retrying via {hs.BaseUrl}...",
+ ex));
+ try {
+ resolvedUri = await hsResolver.ResolveMediaUri(hs.BaseUrl, mxcUri);
+ fileHash = await hashAlgo.ComputeHashAsync(await hs.ClientHttpClient.GetStreamAsync(resolvedUri));
+ }
+ catch (Exception ex2) {
+ await _logRoom.SendMessageEventAsync(
+ MessageFormatter.FormatException($"Error calculating file hash via {hs.BaseUrl} ({resolvedUri})!", ex2));
+ }
+ }
+
+ logger.LogInformation("Checking media {url} with hash {hash}", resolvedUri, fileHash);
+
+ if (ActivePoliciesByType.ContainsKey(nameof(MediaPolicyFile)))
+ foreach (MediaPolicyFile policy in ActivePoliciesByType[nameof(MediaPolicyFile)]) {
+ logger.LogInformation("Checking rule {rule}: {data}", policy.OriginalEvent.StateKey, policy.OriginalEvent.TypedContent.ToJson(ignoreNull: true, indent: false));
+ if (policy.Entity is not null && Convert.ToBase64String(uriHash).SequenceEqual(policy.Entity)) {
+ logger.LogInformation("{url} matched rule {rule} by uri hash", @event.RawContent["url"], policy.ToJson(ignoreNull: true));
+ matchedRules.Add(policy);
+ // continue;
+ }
+ else logger.LogInformation("uri hash {uriHash} did not match rule's {ruleUriHash}", Convert.ToHexString(uriHash), policy.Entity);
+
+ if (policy.FileHash is not null && fileHash is not null && policy.FileHash == Convert.ToBase64String(fileHash)) {
+ logger.LogInformation("{url} matched rule {rule} by file hash", @event.RawContent["url"], policy.ToJson(ignoreNull: true));
+ matchedRules.Add(policy);
+ // continue;
+ }
+ else logger.LogInformation("file hash {fileHash} did not match rule's {ruleFileHash}", Convert.ToBase64String(fileHash), policy.FileHash);
+
+ //check pixels every 10% of the way through the image using ImageSharp
+ // var image = Image.Load(await _hs._httpClient.GetStreamAsync(resolvedUri));
+ }
+ else logger.LogInformation("No active media file policies");
+ // logger.LogInformation("{url} did not match any rules", @event.RawContent["url"]);
+
+ return matchedRules;
+ }
+
+#endregion
+
+#region Internal code
+
+#region Summarisation
+
+ private static (string Raw, string Html) SummariseStateTypeCounts(IList<StateEventResponse> states) {
+ string raw = "Count | State type | Mapped type", html = "<table><tr><th>Count</th><th>State type</th><th>Mapped type</th></tr>";
+ var groupedStates = states.GroupBy(x => x.Type).ToDictionary(x => x.Key, x => x.ToList()).OrderByDescending(x => x.Value.Count);
+ foreach (var (type, stateGroup) in groupedStates) {
+ raw += $"{stateGroup.Count} | {type} | {stateGroup[0].GetType.Name}";
+ html += $"<tr><td>{stateGroup.Count}</td><td>{type}</td><td>{stateGroup[0].GetType.Name}</td></tr>";
+ }
+
+ html += "</table>";
+ return (raw, html);
+ }
+
+#endregion
+
+#endregion
+
+}
\ No newline at end of file
|