diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 6982c7291a..f9ae50f40a 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -23,6 +23,7 @@ import hashlib
import hmac
import json
import os
+import time
import urllib.parse
from binascii import unhexlify
from http import HTTPStatus
@@ -56,6 +57,7 @@ from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from tests import unittest
+from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.test_utils import SMALL_PNG
from tests.unittest import override_config
@@ -5127,7 +5129,6 @@ class UserRedactionTestCase(unittest.HomeserverTestCase):
"""
Test that request to redact events in all rooms user is member of is successful
"""
-
# join rooms, send some messages
originals = []
for rm in [self.rm1, self.rm2, self.rm3]:
@@ -5404,3 +5405,98 @@ class UserRedactionTestCase(unittest.HomeserverTestCase):
matches.append((event_id, event))
# we redacted 6 messages
self.assertEqual(len(matches), 6)
+
+
+class UserRedactionBackgroundTaskTestCase(BaseMultiWorkerStreamTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ admin.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin = self.register_user("thomas", "pass", True)
+ self.admin_tok = self.login("thomas", "pass")
+
+ self.bad_user = self.register_user("teresa", "pass")
+ self.bad_user_tok = self.login("teresa", "pass")
+
+ # create rooms - room versions 11+ store the `redacts` key in content while
+ # earlier ones don't so we use a mix of room versions
+ self.rm1 = self.helper.create_room_as(
+ self.admin, tok=self.admin_tok, room_version="7"
+ )
+ self.rm2 = self.helper.create_room_as(self.admin, tok=self.admin_tok)
+ self.rm3 = self.helper.create_room_as(
+ self.admin, tok=self.admin_tok, room_version="11"
+ )
+
+ @override_config({"run_background_tasks_on": "worker1"})
+ def test_redact_messages_all_rooms(self) -> None:
+ """
+ Test that redact task successfully runs when `run_background_tasks_on` is specified
+ """
+ self.make_worker_hs(
+ "synapse.app.generic_worker",
+ extra_config={
+ "worker_name": "worker1",
+ "run_background_tasks_on": "worker1",
+ "redis": {"enabled": True},
+ },
+ )
+
+ # join rooms, send some messages
+ original_event_ids = set()
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
+ original_event_ids.add(join["event_id"])
+ for i in range(15):
+ event = {"body": f"hello{i}", "msgtype": "m.text"}
+ res = self.helper.send_event(
+ rm, "m.room.message", event, tok=self.bad_user_tok, expect_code=200
+ )
+ original_event_ids.add(res["event_id"])
+
+ # redact all events in all rooms
+ channel = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": []},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ id = channel.json_body.get("redact_id")
+
+ timeout_s = 10
+ start_time = time.time()
+ redact_result = ""
+ while redact_result != "complete":
+ if start_time + timeout_s < time.time():
+ self.fail("Timed out waiting for redactions.")
+
+ channel2 = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id}",
+ access_token=self.admin_tok,
+ )
+ redact_result = channel2.json_body["status"]
+ if redact_result == "failed":
+ self.fail("Redaction task failed.")
+
+ redaction_ids = set()
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel = self.make_request(
+ "GET",
+ f"rooms/{rm}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ for event in channel.json_body["chunk"]:
+ if event["type"] == "m.room.redaction":
+ redaction_ids.add(event["redacts"])
+
+ self.assertIncludes(redaction_ids, original_event_ids, exact=True)
|