using System.Diagnostics;
using System.Text.Json.Nodes;
using ArcaneLibs.Attributes;
using ArcaneLibs.Extensions;
using LibMatrix;
using LibMatrix.EventTypes.Spec;
using LibMatrix.EventTypes.Spec.State.Policy;
using LibMatrix.Helpers;
using LibMatrix.Homeservers;
using LibMatrix.RoomTypes;
using LibMatrix.Utilities.Bot.Interfaces;
using MiniUtils.Core.Classes;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;

namespace MiniUtils.Core;

/// <summary>
/// Hosted service that applies moderation policies: it checks incoming and outstanding room invites
/// against the policies held by <paramref name="policyStore"/> and, when enabled through
/// <c>config.IgnoreBannedUsers</c>, mirrors user ban policies into the account's ignored-user list.
/// </summary>
/// <param name="logger">Logger used for diagnostics.</param>
/// <param name="config">Configuration; this class reads <c>LogRoom</c> and <c>IgnoreBannedUsers</c>.</param>
/// <param name="roomInviteHandler">Source of invite notifications and of the outstanding invites.</param>
/// <param name="policyStore">Source of policies and of policy change notifications.</param>
/// <param name="homeserver">Homeserver client used for account data and room access.</param>
public class PolicyExecutor(
    ILogger<PolicyExecutor> logger,
    AntiDmSpamConfiguration config,
    RoomInviteHandler roomInviteHandler,
    PolicyStore policyStore,
    AuthenticatedHomeserverGeneric homeserver) : IHostedService {
    // Room that receives human-readable notices; null when no log room is configured.
    private readonly GenericRoom? _logRoom = string.IsNullOrWhiteSpace(config.LogRoom) ? null : homeserver.GetRoom(config.LogRoom);

    /// <summary>
    /// Registers the invite and policy handlers. When <c>config.IgnoreBannedUsers</c> is set, the
    /// ignore list is cleaned up once before the ignore-list handler is registered.
    /// </summary>
    public async Task StartAsync(CancellationToken cancellationToken) {
        roomInviteHandler.OnInviteReceived.Add(CheckPoliciesAgainstInvite);
        policyStore.OnPolicyAdded.Add(CheckPolicyAgainstOutstandingInvites);
        if (config.IgnoreBannedUsers) {
            await CleanupInvalidIgnoreListEntries();
            policyStore.OnPoliciesChanged.Add(UpdateIgnoreList);
        }
    }

    /// <summary>No-op; the handlers registered in <see cref="StartAsync"/> are left in place.</summary>
    public Task StopAsync(CancellationToken cancellationToken) => Task.CompletedTask;

    #region Feature: Manage ignore list

    /// <summary>
    /// Applies a batch of policy changes to the ignored-user list and stores the result as account data.
    /// New non-glob user ban policies add (or annotate) an ignore entry; removed ones drop their
    /// reference and delete the entry once nothing else keeps it alive.
    /// </summary>
    /// <param name="updates">The added, updated and removed policy events of this batch.</param>
    /// <exception cref="InvalidOperationException">A new policy event has no room ID.</exception>
    private async Task UpdateIgnoreList((
        List<StateEventResponse> NewPolicies,
        List<(StateEventResponse Old, StateEventResponse New)> UpdatedPolicies,
        List<(StateEventResponse Old, StateEventResponse New)> RemovedPolicies) updates
    ) {
        // NOTE(review): updates.UpdatedPolicies is not processed here - confirm that is intended.
        var ignoreListContent = await FilterInvalidIgnoreListEntries();

        foreach (var newEvent in updates.NewPolicies) {
            // Same filter as the removal loop below, so everything that gets added can also be removed
            // again, and server/room entities never end up in the ignored *users* list.
            if (newEvent.Type != UserPolicyRuleEventContent.EventId) continue;
            if (newEvent.TypedContent is not PolicyRuleEventContent content) continue;
            if (content.Entity is null || content.IsGlobRule()) continue;
            if (content.GetNormalizedRecommendation() != "m.ban") continue;

            var policyEventReference = new MadsIgnoreMetadataContent.PolicyEventReference() {
                Type = newEvent.Type,
                RoomId = newEvent.RoomId ?? throw new InvalidOperationException("RoomId is null"),
                StateKey = newEvent.StateKey!
            };

            if (ignoreListContent.IgnoredUsers.TryGetValue(content.Entity, out var existingRule)) {
                if (existingRule.AdditionalData?.ContainsKey(MadsIgnoreMetadataContent.EventId) ?? false) {
                    var existingMetadata = existingRule.GetAdditionalData<MadsIgnoreMetadataContent>(MadsIgnoreMetadataContent.EventId);
                    existingMetadata.Policies.Add(policyEventReference);
                    // Store the modified object back: GetAdditionalData may hand out a deserialized copy,
                    // in which case the Add above would otherwise be lost.
                    existingRule.AdditionalData![MadsIgnoreMetadataContent.EventId] = existingMetadata;
                }
                else {
                    // The entry existed without our metadata, so it is recorded as added by the user.
                    existingRule.AdditionalData ??= new();
                    existingRule.AdditionalData.Add(MadsIgnoreMetadataContent.EventId, new MadsIgnoreMetadataContent {
                        WasUserAdded = true,
                        Policies = [policyEventReference]
                    });
                }
            }
            else {
                ignoreListContent.IgnoredUsers[content.Entity] = new() {
                    AdditionalData = new() {
                        [MadsIgnoreMetadataContent.EventId] = new MadsIgnoreMetadataContent {
                            WasUserAdded = false,
                            Policies = [policyEventReference]
                        }
                    }
                };
            }
        }

        foreach (var (previousEvent, _) in updates.RemovedPolicies) {
            if (previousEvent.Type != UserPolicyRuleEventContent.EventId) continue;
            var previousContent = previousEvent.ContentAs<UserPolicyRuleEventContent>();
            if (previousContent?.Entity is null || previousContent.IsGlobRule()) continue;
            if (previousContent.GetNormalizedRecommendation() != "m.ban") continue;

            // Work on the list that is saved below. Editing a separately fetched copy would discard
            // every removal when ignoreListContent is written.
            if (!ignoreListContent.IgnoredUsers.TryGetValue(previousContent.Entity, out var existingRule)) continue;
            if (!(existingRule.AdditionalData?.ContainsKey(MadsIgnoreMetadataContent.EventId) ?? false)) continue;

            var existingMetadata = existingRule.GetAdditionalData<MadsIgnoreMetadataContent>(MadsIgnoreMetadataContent.EventId);
            existingMetadata.Policies.RemoveAll(x =>
                x.Type == previousEvent.Type && x.RoomId == previousEvent.RoomId && x.StateKey == previousEvent.StateKey);

            // Same criterion as CleanupInvalidIgnoreListEntries: only un-ignore once no policy references
            // the user any more and the user did not add the entry themselves.
            if (existingMetadata.Policies.Count == 0 && !existingMetadata.WasUserAdded)
                ignoreListContent.IgnoredUsers.Remove(previousContent.Entity);
            else
                existingRule.AdditionalData![MadsIgnoreMetadataContent.EventId] = existingMetadata;
        }

        await homeserver.SetAccountDataAsync(IgnoredUserListEventContent.EventId, ignoreListContent);
    }

    /// <summary>
    /// Loads the ignored-user list from account data and strips malformed policy references from the
    /// metadata attached to its entries.
    /// </summary>
    /// <returns>The filtered list, or an empty list when no ignored-user account data exists.</returns>
    private async Task<IgnoredUserListEventContent> FilterInvalidIgnoreListEntries() {
        var ignoreList = await homeserver.GetAccountDataOrNullAsync<IgnoredUserListEventContent>(IgnoredUserListEventContent.EventId);
        // Callers dereference the result unconditionally, so never hand back null.
        if (ignoreList is null) return new IgnoredUserListEventContent { IgnoredUsers = new() };

        ignoreList.IgnoredUsers.RemoveAll((id, ignoredUserData) => {
            if (ignoredUserData.AdditionalData is null) return false;
            if (!ignoredUserData.AdditionalData.ContainsKey(MadsIgnoreMetadataContent.EventId)) return false;
            var metadata = ignoredUserData.GetAdditionalData<JsonObject>(MadsIgnoreMetadataContent.EventId)!;

            if (metadata.ContainsKey("policies")) {
                var policies = metadata["policies"]!.AsArray();

                static bool HasStringProperty(JsonObject entry, string key) =>
                    entry[key] is JsonValue value && value.TryGetValue<string>(out _);

                // Null items, non-object items and non-string fields count as invalid instead of throwing.
                static bool IsPolicyEntryValid(JsonNode? p) =>
                    p is JsonObject entry
                    && HasStringProperty(entry, "room_id")
                    && HasStringProperty(entry, "type")
                    && HasStringProperty(entry, "state_key");

                if (policies.Any(x => !IsPolicyEntryValid(x))) {
                    logger.LogWarning("Found invalid policy reference in ignore list, removing! {policy}",
                        policies.Where(x => !IsPolicyEntryValid(x)).Select(x => x?.ToJson(ignoreNull: true) ?? "null"));
                    // A JsonNode can only have one parent; the surviving entries still belong to the old
                    // array, so they have to be cloned before being placed into a new one.
                    metadata["policies"] = new JsonArray(policies.Where(IsPolicyEntryValid).Select(x => x!.DeepClone()).ToArray());
                    // Store the filtered metadata back in case GetAdditionalData returned a copy.
                    ignoredUserData.AdditionalData[MadsIgnoreMetadataContent.EventId] = metadata;
                }
            }

            // NOTE(review): assuming RemoveAll drops the entries this predicate returns true for, every
            // entry whose "was_user_added" is absent or false is removed from the returned list,
            // regardless of its remaining policies - confirm that this is intended.
            return metadata["was_user_added"]?.GetValue<bool>() is null or false;
        });

        return ignoreList;
    }

    /// <summary>
    /// Cleans the stored ignored-user list: filters it via <see cref="FilterInvalidIgnoreListEntries"/>,
    /// drops keys that do not look like user IDs, drops entries that have no policy references left and
    /// were not added by the user, then writes the list back to account data.
    /// </summary>
    private async Task CleanupInvalidIgnoreListEntries() {
        var ignoreList = await FilterInvalidIgnoreListEntries();
        // Keep only keys that start with '@' and contain ':'.
        ignoreList.IgnoredUsers.RemoveAll((id, _) => !(id.StartsWith('@') && id.Contains(':')));

        // Collected first, because the dictionary must not be modified while it is being enumerated.
        List<string> idsToRemove = [];
        foreach (var (id, ignoredUserData) in ignoreList.IgnoredUsers) {
            if (ignoredUserData.AdditionalData is null) continue;
            if (!ignoredUserData.AdditionalData.ContainsKey(MadsIgnoreMetadataContent.EventId)) continue;

            try {
                var metadata = ignoredUserData.GetAdditionalData<MadsIgnoreMetadataContent>(MadsIgnoreMetadataContent.EventId)!;
                if (metadata.Policies.Count == 0 && !metadata.WasUserAdded) {
                    idsToRemove.Add(id);
                }
            }
            catch (Exception e) {
                // Best effort: an entry that cannot be read is logged and kept.
                logger.LogError(e, "Failed to parse ignore list entry for {}", id);
            }
        }

        foreach (var id in idsToRemove) {
            ignoreList.IgnoredUsers.Remove(id);
        }

        await homeserver.SetAccountDataAsync(IgnoredUserListEventContent.EventId, ignoreList);
    }

    #endregion

    #region Feature: Report blocked invites

    #endregion

    #region Feature: Reject invites

    /// <summary>
    /// Checks a newly received invite against every known policy. The rejection started by
    /// <see cref="CheckPolicyAgainstInvite"/> is not awaited here; this method returns a completed task.
    /// </summary>
    /// <param name="invite">The invite to check.</param>
    private Task CheckPoliciesAgainstInvite(RoomInviteContext invite) {
        logger.LogInformation("Checking policies against invite");
        var sw = Stopwatch.StartNew();

        // Technically not required, but helps with scaling against millions of policies
        Parallel.ForEach(policyStore.AllPolicies.Values, (policy, loopState, idx) => {
            if (CheckPolicyAgainstInvite(invite, policy) is not null) {
                logger.LogInformation("Found matching policy after {} iterations ({})", idx, sw.Elapsed);
                loopState.Break();
            }
        });

        return Task.CompletedTask;
    }

    /// <summary>
    /// Checks one newly added policy against all outstanding invites and waits for the resulting
    /// rejections to finish.
    /// </summary>
    /// <param name="newEvent">The policy event that was added.</param>
    private async Task CheckPolicyAgainstOutstandingInvites(StateEventResponse newEvent) {
        var tasks = roomInviteHandler.Invites
            .Select(invite => CheckPolicyAgainstInvite(invite, newEvent))
            .OfType<Task>() // drops the nulls returned for invites that did not match
            .ToList();

        await Task.WhenAll(tasks);
    }

    /// <summary>
    /// Tests a single policy against a single invite and starts the rejection when it matches.
    /// </summary>
    /// <param name="invite">The invite to test.</param>
    /// <param name="policyEvent">The policy state event to test against.</param>
    /// <returns>The running rejection task, or null when the policy is not a ban or does not match.</returns>
    /// <exception cref="InvalidOperationException">The event content is not a policy rule.</exception>
    private Task? CheckPolicyAgainstInvite(RoomInviteContext invite, StateEventResponse policyEvent) {
        var policy = policyEvent.TypedContent as PolicyRuleEventContent ?? throw new InvalidOperationException("Policy is null");
        // Normalized like in the ignore-list code, so both features compare against the same value.
        if (policy.GetNormalizedRecommendation() != "m.ban") return null;

        var policyMatches = false;
        switch (policy) {
            case UserPolicyRuleEventContent userPolicy:
                policyMatches = userPolicy.EntityMatches(invite.MemberEvent.Sender!);
                break;
            case ServerPolicyRuleEventContent serverPolicy:
                // NOTE(review): this passes the full sender ID - presumably EntityMatches extracts the
                // server name itself; verify.
                policyMatches = serverPolicy.EntityMatches(invite.MemberEvent.Sender!);
                break;
            case RoomPolicyRuleEventContent roomPolicy:
                policyMatches = roomPolicy.EntityMatches(invite.RoomId);
                break;
            default:
                // Fire and forget: an unknown policy type is only reported and never matches.
                if (_logRoom is not null)
                    _ = _logRoom.SendMessageEventAsync(new MessageBuilder().WithColoredBody("#FF0000", "Unknown policy type " + policy.GetType().FullName).Build());
                break;
        }

        if (!policyMatches) return null;

        logger.LogWarning("[{}] Rejecting invite to {}, matching {} {}",
            homeserver.WhoAmI.UserId, invite.RoomId, policy.GetType().GetFriendlyName(), policy.ToJson(ignoreNull: true));

        return Task.Run(async () => {
            // NOTE(review): the invite is only rejected when a log room is configured - confirm whether
            // RejectInvite should also run without one.
            if (_logRoom is not null) {
                string roomName = await invite.TryGetRoomNameAsync();

                await roomInviteHandler.RejectInvite(invite, new MessageBuilder()
                    .WithColoredBody("#FF0000",
                        cb => cb.WithBody("Rejecting invite to ").WithMention(invite.RoomId, roomName)
                            .WithBody($", matching {policy.GetType().GetFriendlyName().ToLowerInvariant()}.")
                            .WithNewline())
                    .WithCollapsibleSection("Policy JSON", cb => cb.WithCodeBlock(policy.ToJson(ignoreNull: true), "json"))
                );
            }
        });
    }

    #endregion
}