diff --git a/MiniUtils.Core/PolicyExecutor.cs b/MiniUtils.Core/PolicyExecutor.cs
new file mode 100644
index 0000000..80bbb99
--- /dev/null
+++ b/MiniUtils.Core/PolicyExecutor.cs
@@ -0,0 +1,230 @@
+using System.Diagnostics;
+using System.Text.Json.Nodes;
+using ArcaneLibs.Attributes;
+using ArcaneLibs.Extensions;
+using LibMatrix;
+using LibMatrix.EventTypes.Spec;
+using LibMatrix.EventTypes.Spec.State.Policy;
+using LibMatrix.Helpers;
+using LibMatrix.Homeservers;
+using LibMatrix.RoomTypes;
+using LibMatrix.Utilities.Bot.Interfaces;
+using MiniUtils.Core.Classes;
+using Microsoft.Extensions.Hosting;
+using Microsoft.Extensions.Logging;
+
+namespace MiniUtils.Core;
+
+public class PolicyExecutor(
+ ILogger<PolicyExecutor> logger,
+ AntiDmSpamConfiguration config,
+ RoomInviteHandler roomInviteHandler,
+ PolicyStore policyStore,
+ AuthenticatedHomeserverGeneric homeserver) : IHostedService {
+ private readonly GenericRoom? _logRoom = string.IsNullOrWhiteSpace(config.LogRoom) ? null : homeserver.GetRoom(config.LogRoom);
+
+ public async Task StartAsync(CancellationToken cancellationToken) {
+ roomInviteHandler.OnInviteReceived.Add(CheckPoliciesAgainstInvite);
+ policyStore.OnPolicyAdded.Add(CheckPolicyAgainstOutstandingInvites);
+ if (config.IgnoreBannedUsers) {
+ await CleanupInvalidIgnoreListEntries();
+ policyStore.OnPoliciesChanged.Add(UpdateIgnoreList);
+ }
+ }
+
+ public Task StopAsync(CancellationToken cancellationToken) => Task.CompletedTask;
+
+#region Feature: Manage ignore list
+
+ private async Task UpdateIgnoreList((
+ List<StateEventResponse> NewPolicies,
+ List<(StateEventResponse Old, StateEventResponse New)> UpdatedPolicies,
+ List<(StateEventResponse Old, StateEventResponse New)> RemovedPolicies) updates
+ ) {
+ var ignoreListContent = await FilterInvalidIgnoreListEntries();
+
+ foreach (var newEvent in updates.NewPolicies) {
+ var content = newEvent.TypedContent as PolicyRuleEventContent;
+ if (content.Entity is null || content.IsGlobRule()) continue;
+ if (content.GetNormalizedRecommendation() != "m.ban") continue;
+
+ var policyEventReference = new MadsIgnoreMetadataContent.PolicyEventReference() {
+ Type = newEvent.Type,
+ RoomId = newEvent.RoomId ?? throw new InvalidOperationException("RoomId is null"),
+ StateKey = newEvent.StateKey!
+ };
+
+ if (ignoreListContent.IgnoredUsers.TryGetValue(content.Entity, out var existingRule)) {
+ if (existingRule.AdditionalData?.ContainsKey(MadsIgnoreMetadataContent.EventId) ?? false) {
+ var existingMetadata = existingRule.GetAdditionalData<MadsIgnoreMetadataContent>(MadsIgnoreMetadataContent.EventId);
+ existingMetadata.Policies.Add(policyEventReference);
+ }
+ else {
+ existingRule.AdditionalData ??= new();
+ existingRule.AdditionalData.Add(MadsIgnoreMetadataContent.EventId, new MadsIgnoreMetadataContent {
+ WasUserAdded = true,
+ Policies = [policyEventReference]
+ });
+ }
+ }
+ else {
+ ignoreListContent.IgnoredUsers[content.Entity] = new() {
+ AdditionalData = new() {
+ [MadsIgnoreMetadataContent.EventId] = new MadsIgnoreMetadataContent {
+ WasUserAdded = false,
+ Policies = [policyEventReference]
+ }
+ }
+ };
+ }
+ }
+
+ foreach (var (previousEvent, newEvent) in updates.RemovedPolicies) {
+ if (previousEvent.Type != UserPolicyRuleEventContent.EventId) continue;
+ var previousContent = previousEvent.ContentAs<UserPolicyRuleEventContent>();
+
+ if (previousContent.Entity is null || previousContent.IsGlobRule()) continue;
+ if (previousContent.GetNormalizedRecommendation() != "m.ban") continue;
+
+ var ignoreList = await homeserver.GetIgnoredUserListAsync();
+ if (ignoreList.IgnoredUsers.TryGetValue(previousContent.Entity, out var existingRule)) {
+ if (existingRule.AdditionalData?.ContainsKey(MadsIgnoreMetadataContent.EventId) ?? false) {
+ var existingMetadata = existingRule.GetAdditionalData<MadsIgnoreMetadataContent>(MadsIgnoreMetadataContent.EventId);
+ existingMetadata.Policies.RemoveAll(x => x.Type == previousEvent.Type && x.RoomId == previousEvent.RoomId && x.StateKey == previousEvent.StateKey);
+ if (!existingMetadata.WasUserAdded)
+ ignoreList.IgnoredUsers.Remove(previousContent.Entity);
+ }
+ }
+ }
+
+ await homeserver.SetAccountDataAsync(IgnoredUserListEventContent.EventId, ignoreListContent);
+ }
+
+ private async Task<IgnoredUserListEventContent> FilterInvalidIgnoreListEntries() {
+ var ignoreList = await homeserver.GetAccountDataOrNullAsync<IgnoredUserListEventContent>(IgnoredUserListEventContent.EventId);
+ if (ignoreList != null) {
+ ignoreList.IgnoredUsers.RemoveAll((id, ignoredUserData) => {
+ if (ignoredUserData.AdditionalData is null) return false;
+ if (!ignoredUserData.AdditionalData.ContainsKey(MadsIgnoreMetadataContent.EventId)) return false;
+ var metadata = ignoredUserData.GetAdditionalData<JsonObject>(MadsIgnoreMetadataContent.EventId)!;
+
+ if (metadata.ContainsKey("policies")) {
+ var policies = metadata["policies"]!.AsArray();
+
+ bool IsPolicyEntryValid(JsonNode? p) =>
+ p!["room_id"]?.GetValue<string>() != null && p["type"]?.GetValue<string>() != null && p["state_key"]?.GetValue<string>() != null;
+
+ if (policies.Any(x => !IsPolicyEntryValid(x))) {
+ logger.LogWarning("Found invalid policy reference in ignore list, removing! {policy}",
+ policies.Where(x => !IsPolicyEntryValid(x)).Select(x => x.ToJson(ignoreNull: true)));
+ metadata["policies"] = new JsonArray(policies.Where(IsPolicyEntryValid).ToArray());
+ }
+ }
+
+ return metadata["was_user_added"]?.GetValue<bool>() is null or false;
+ });
+ }
+
+ return ignoreList;
+ }
+
+ private async Task CleanupInvalidIgnoreListEntries() {
+ var ignoreList = await FilterInvalidIgnoreListEntries();
+ ignoreList.IgnoredUsers.RemoveAll((id, _) => !(id.StartsWith('@') && id.Contains(':')));
+ List<string> idsToRemove = [];
+ foreach (var (id, ignoredUserData) in ignoreList.IgnoredUsers) {
+ if (ignoredUserData.AdditionalData is null) continue;
+ if (!ignoredUserData.AdditionalData.ContainsKey(MadsIgnoreMetadataContent.EventId)) continue;
+ try {
+ var metadata = ignoredUserData.GetAdditionalData<MadsIgnoreMetadataContent>(MadsIgnoreMetadataContent.EventId)!;
+
+ if (metadata.Policies.Count == 0 && !metadata.WasUserAdded) {
+ idsToRemove.Add(id);
+ }
+ }
+ catch (Exception e) {
+ logger.LogError(e, "Failed to parse ignore list entry for {}", id);
+ }
+ }
+
+ foreach (var id in idsToRemove) {
+ ignoreList.IgnoredUsers.Remove(id);
+ }
+ await homeserver.SetAccountDataAsync(IgnoredUserListEventContent.EventId, ignoreList);
+ }
+
+#endregion
+
+#region Feature: Report blocked invites
+
+#endregion
+
+#region Feature: Reject invites
+
+ private Task CheckPoliciesAgainstInvite(RoomInviteContext invite) {
+ logger.LogInformation("Checking policies against invite");
+ var sw = Stopwatch.StartNew();
+
+ // Technically not required, but helps with scaling against millions of policies
+ Parallel.ForEach(policyStore.AllPolicies.Values, (policy, loopState, idx) => {
+ if (CheckPolicyAgainstInvite(invite, policy) is not null) {
+ logger.LogInformation("Found matching policy after {} iterations ({})", idx, sw.Elapsed);
+ loopState.Break();
+ }
+ });
+
+ return Task.CompletedTask;
+ }
+
+ private async Task CheckPolicyAgainstOutstandingInvites(StateEventResponse newEvent) {
+ var tasks = roomInviteHandler.Invites
+ .Select(invite => CheckPolicyAgainstInvite(invite, newEvent))
+ .Where(x => x is not null)
+ .Cast<Task>() // from Task?
+ .ToList();
+
+ await Task.WhenAll(tasks);
+ }
+
+ private Task? CheckPolicyAgainstInvite(RoomInviteContext invite, StateEventResponse policyEvent) {
+ var policy = policyEvent.TypedContent as PolicyRuleEventContent ?? throw new InvalidOperationException("Policy is null");
+ if (policy.Recommendation != "m.ban") return null;
+
+ var policyMatches = false;
+ switch (policy) {
+ case UserPolicyRuleEventContent userPolicy:
+ policyMatches = userPolicy.EntityMatches(invite.MemberEvent.Sender!);
+ break;
+ case ServerPolicyRuleEventContent serverPolicy:
+ policyMatches = serverPolicy.EntityMatches(invite.MemberEvent.Sender!);
+ break;
+ case RoomPolicyRuleEventContent roomPolicy:
+ policyMatches = roomPolicy.EntityMatches(invite.RoomId);
+ break;
+ default:
+ if (_logRoom is not null)
+ _ = _logRoom.SendMessageEventAsync(new MessageBuilder().WithColoredBody("#FF0000", "Unknown policy type " + policy.GetType().FullName).Build());
+ break;
+ }
+
+ if (!policyMatches) return null;
+ logger.LogWarning("[{}] Rejecting invite to {}, matching {} {}", homeserver.WhoAmI.UserId, invite.RoomId, policy.GetType().GetFriendlyName(),
+ policy.ToJson(ignoreNull: true));
+
+ return Task.Run(async () => {
+ if (_logRoom is not null) {
+ string roomName = await invite.TryGetRoomNameAsync();
+
+ await roomInviteHandler.RejectInvite(invite, new MessageBuilder()
+ .WithColoredBody("#FF0000",
+ cb => cb.WithBody("Rejecting invite to ").WithMention(invite.RoomId, roomName)
+ .WithBody($", matching {policy.GetType().GetFriendlyName().ToLowerInvariant()}.")
+ .WithNewline())
+ .WithCollapsibleSection("Policy JSON", cb => cb.WithCodeBlock(policy.ToJson(ignoreNull: true), "json"))
+ );
+ }
+ });
+ }
+
+#endregion
+}
\ No newline at end of file
|