diff --git a/PolicyEngine.cs b/PolicyEngine.cs
index 0d0ed65..7556fc5 100644
--- a/PolicyEngine.cs
+++ b/PolicyEngine.cs
@@ -1,4 +1,5 @@
using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
using System.Security.Cryptography;
using System.Text.Json;
using System.Text.RegularExpressions;
@@ -12,13 +13,15 @@ using LibMatrix.RoomTypes;
using LibMatrix.Services;
using Microsoft.Extensions.Logging;
using ModerationBot.AccountData;
+using ModerationBot.Services;
using ModerationBot.StateEventTypes.Policies;
using ModerationBot.StateEventTypes.Policies.Implementations;
namespace ModerationBot;
-public class PolicyEngine(AuthenticatedHomeserverGeneric hs, ILogger<ModerationBot> logger, ModerationBotConfiguration configuration, HomeserverResolverService hsResolver) {
+public class PolicyEngine(AuthenticatedHomeserverGeneric hs, ILogger<ModerationBot> logger, ModerationBotConfiguration configuration, HomeserverResolverService hsResolver, ModerationBotRoomProvider roomProvider) {
private Dictionary<string, PolicyList> PolicyListAccountData { get; set; } = new();
+
// ReSharper disable once MemberCanBePrivate.Global
public List<PolicyList> ActivePolicyLists { get; set; } = new();
public List<BasePolicy> ActivePolicies { get; set; } = new();
@@ -28,11 +31,10 @@ public class PolicyEngine(AuthenticatedHomeserverGeneric hs, ILogger<ModerationB
public async Task ReloadActivePolicyLists() {
var sw = Stopwatch.StartNew();
-
- var botData = await hs.GetAccountDataAsync<BotData>("gay.rory.moderation_bot_data");
- _logRoom ??= hs.GetRoom(botData.LogRoom ?? botData.ControlRoom);
- _controlRoom ??= hs.GetRoom(botData.ControlRoom);
-
+
+ _logRoom = await roomProvider.GetLogRoomAsync();
+ _controlRoom = await roomProvider.GetControlRoomAsync();
+
await _controlRoom?.SendMessageEventAsync(MessageFormatter.FormatSuccess("Reloading policy lists!"))!;
await _logRoom?.SendMessageEventAsync(MessageFormatter.FormatSuccess("Reloading policy lists!"))!;
@@ -46,12 +48,12 @@ public class PolicyEngine(AuthenticatedHomeserverGeneric hs, ILogger<ModerationB
if (e is not { ErrorCode: "M_NOT_FOUND" }) throw;
}
- if (!PolicyListAccountData.ContainsKey(botData.DefaultPolicyRoom!)) {
- PolicyListAccountData.Add(botData.DefaultPolicyRoom!, new PolicyList() {
- Trusted = true
- });
- await hs.SetAccountDataAsync("gay.rory.moderation_bot.policy_lists", PolicyListAccountData);
- }
+ // if (!PolicyListAccountData.ContainsKey(botData.DefaultPolicyRoom!)) {
+ // PolicyListAccountData.Add(botData.DefaultPolicyRoom!, new PolicyList() {
+ // Trusted = true
+ // });
+ // await hs.SetAccountDataAsync("gay.rory.moderation_bot.policy_lists", PolicyListAccountData);
+ // }
var loadTasks = new List<Task<PolicyList>>();
foreach (var (roomId, policyList) in PolicyListAccountData) {
@@ -62,11 +64,11 @@ public class PolicyEngine(AuthenticatedHomeserverGeneric hs, ILogger<ModerationB
await foreach (var policyList in loadTasks.ToAsyncEnumerable()) {
policyLists.Add(policyList);
- if (false || policyList.Policies.Count >= 256 || policyLists.Count == PolicyListAccountData.Count) {
+ if (true || policyList.Policies.Count >= 256 || policyLists.Count == PolicyListAccountData.Count) {
var progressMsgContent = MessageFormatter.FormatSuccess($"{policyLists.Count}/{PolicyListAccountData.Count} policy lists loaded, " +
$"{policyLists.Sum(x => x.Policies.Count)} policies total, {sw.Elapsed} elapsed.")
.SetReplaceRelation<RoomMessageEventContent>(progressMessage.EventId);
-
+
await _logRoom?.SendMessageEventAsync(progressMsgContent);
}
}
@@ -99,7 +101,6 @@ public class PolicyEngine(AuthenticatedHomeserverGeneric hs, ILogger<ModerationB
return policyList;
}
-
public async Task ReloadActivePolicyListById(string roomId) {
if (!ActivePolicyLists.Any(x => x.Room.RoomId == roomId)) return;
await LoadPolicyListAsync(hs.GetRoom(roomId), ActivePolicyLists.Single(x => x.Room.RoomId == roomId));
@@ -159,7 +160,7 @@ public class PolicyEngine(AuthenticatedHomeserverGeneric hs, ILogger<ModerationB
return matchingPolicies;
}
- #region Policy matching
+#region Policy matching
private async Task<List<BasePolicy>> CheckMessageContent(StateEventResponse @event) {
var matchedRules = new List<BasePolicy>();
@@ -167,11 +168,11 @@ public class PolicyEngine(AuthenticatedHomeserverGeneric hs, ILogger<ModerationB
if (ActivePoliciesByType.TryGetValue(nameof(MessagePolicyContainsText), out var messageContainsPolicies))
foreach (var policy in messageContainsPolicies) {
- if ((@msgContent?.Body?.ToLowerInvariant().Contains(policy.Entity.ToLowerInvariant()) ?? false) || (@msgContent?.FormattedBody?.ToLowerInvariant().Contains(policy.Entity.ToLowerInvariant()) ?? false))
+ if ((@msgContent?.Body?.ToLowerInvariant().Contains(policy.Entity.ToLowerInvariant()) ?? false) ||
+ (@msgContent?.FormattedBody?.ToLowerInvariant().Contains(policy.Entity.ToLowerInvariant()) ?? false))
matchedRules.Add(policy);
}
-
return matchedRules;
}
@@ -243,11 +244,11 @@ public class PolicyEngine(AuthenticatedHomeserverGeneric hs, ILogger<ModerationB
return matchedRules;
}
- #endregion
+#endregion
- #region Internal code
+#region Internal code
- #region Summarisation
+#region Summarisation
private static (string Raw, string Html) SummariseStateTypeCounts(IList<StateEventResponse> states) {
string raw = "Count | State type | Mapped type", html = "<table><tr><th>Count</th><th>State type</th><th>Mapped type</th></tr>";
@@ -261,8 +262,7 @@ public class PolicyEngine(AuthenticatedHomeserverGeneric hs, ILogger<ModerationB
return (raw, html);
}
- #endregion
-
- #endregion
+#endregion
-}
+#endregion
+}
\ No newline at end of file
|