diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 55a4f95ef3..9659a4a355 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -30,6 +30,7 @@ from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.types import JsonDict, create_requester
from synapse.util import Clock
+from synapse.util.task_scheduler import TaskScheduler
from tests import unittest
from tests.unittest import override_config
@@ -49,6 +50,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
assert isinstance(handler, DeviceHandler)
self.handler = handler
self.store = hs.get_datastores().main
+ self.device_message_handler = hs.get_device_message_handler()
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -211,6 +213,51 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
self.assertIsNone(res)
+ def test_delete_device_and_big_device_inbox(self) -> None:
+ """Check that deleting a big device inbox is staged and batched asynchronously."""
+ DEVICE_ID = "abc"
+ sender = "@sender:" + self.hs.hostname
+ receiver = "@receiver:" + self.hs.hostname
+ self._record_user(sender, DEVICE_ID, DEVICE_ID)
+ self._record_user(receiver, DEVICE_ID, DEVICE_ID)
+
+ # queue a bunch of messages in the inbox
+ requester = create_requester(sender, device_id=DEVICE_ID)
+ for i in range(0, DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT + 10):
+ self.get_success(
+ self.device_message_handler.send_device_message(
+ requester, "message_type", {receiver: {"*": {"val": i}}}
+ )
+ )
+
+ # delete the device
+ self.get_success(self.handler.delete_devices(receiver, [DEVICE_ID]))
+
+ # messages should be deleted up to DEVICE_MSGS_DELETE_BATCH_LIMIT straight away
+ res = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="device_inbox",
+ keyvalues={"user_id": receiver},
+ retcols=("user_id", "device_id", "stream_id"),
+ desc="get_device_id_from_device_inbox",
+ )
+ )
+ self.assertEqual(10, len(res))
+
+ # wait for the task scheduler to do a second delete pass
+ self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)
+
+ # remaining messages should now be deleted
+ res = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="device_inbox",
+ keyvalues={"user_id": receiver},
+ retcols=("user_id", "device_id", "stream_id"),
+ desc="get_device_id_from_device_inbox",
+ )
+ )
+ self.assertEqual(0, len(res))
+
def test_update_device(self) -> None:
self._record_users()
|