summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorEric Eastwood <erice@element.io>2022-09-23 00:16:50 -0500
committerEric Eastwood <erice@element.io>2022-09-23 00:16:50 -0500
commit44e97465e9ff003b4b7be79ba169430d65b0edc7 (patch)
treec5b70013830557451c706459a9f7f10e3757e055 /tests
parentScratch try different orders just to see how the tests pass differently (diff)
parentAdd test to ensure the safety works (diff)
downloadsynapse-44e97465e9ff003b4b7be79ba169430d65b0edc7.tar.xz
Merge branch 'madlittlemods/13856-fix-have-seen-events-not-being-invalidated' into maddlittlemods/msc2716-many-batches-optimization
Conflicts:
	tests/storage/databases/main/test_events_worker.py
Diffstat (limited to 'tests')
-rw-r--r--tests/push/test_email.py4
-rw-r--r--tests/push/test_http.py193
-rw-r--r--tests/replication/test_module_cache_invalidation.py79
-rw-r--r--tests/replication/test_pusher_shard.py2
-rw-r--r--tests/rest/admin/test_user.py2
-rw-r--r--tests/rest/client/test_login_token_request.py132
-rw-r--r--tests/rest/client/test_relations.py29
-rw-r--r--tests/storage/databases/main/test_events_worker.py32
-rw-r--r--tests/unittest.py36
-rw-r--r--tests/util/caches/test_descriptors.py33
10 files changed, 497 insertions, 45 deletions
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 7a3b0d6755..fd14568f55 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase):
         )
 
         self.pusher = self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=self.user_id,
                 access_token=self.token_id,
                 kind="email",
@@ -136,7 +136,7 @@ class EmailPusherTests(HomeserverTestCase):
         """
         with self.assertRaises(SynapseError) as cm:
             self.get_success_or_raise(
-                self.hs.get_pusherpool().add_pusher(
+                self.hs.get_pusherpool().add_or_update_pusher(
                     user_id=self.user_id,
                     access_token=self.token_id,
                     kind="email",
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index d9c68cdd2d..b383b8401f 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -19,9 +19,10 @@ from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.logging.context import make_deferred_yieldable
-from synapse.push import PusherConfigException
-from synapse.rest.client import login, push_rule, receipts, room
+from synapse.push import PusherConfig, PusherConfigException
+from synapse.rest.client import login, push_rule, pusher, receipts, room
 from synapse.server import HomeServer
+from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import JsonDict
 from synapse.util import Clock
 
@@ -35,6 +36,7 @@ class HTTPPusherTests(HomeserverTestCase):
         login.register_servlets,
         receipts.register_servlets,
         push_rule.register_servlets,
+        pusher.register_servlets,
     ]
     user_id = True
     hijack_auth = False
@@ -74,7 +76,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         def test_data(data: Optional[JsonDict]) -> None:
             self.get_failure(
-                self.hs.get_pusherpool().add_pusher(
+                self.hs.get_pusherpool().add_or_update_pusher(
                     user_id=user_id,
                     access_token=token_id,
                     kind="http",
@@ -119,7 +121,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -235,7 +237,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -355,7 +357,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -441,7 +443,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -518,7 +520,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -624,7 +626,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -728,18 +730,38 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         self.assertEqual(channel.code, 200, channel.json_body)
 
-    def _make_user_with_pusher(self, username: str) -> Tuple[str, str]:
+    def _make_user_with_pusher(
+        self, username: str, enabled: bool = True
+    ) -> Tuple[str, str]:
+        """Registers a user and creates a pusher for them.
+
+        Args:
+            username: the localpart of the new user's Matrix ID.
+            enabled: whether to create the pusher in an enabled or disabled state.
+        """
         user_id = self.register_user(username, "pass")
         access_token = self.login(username, "pass")
 
         # Register the pusher
+        self._set_pusher(user_id, access_token, enabled)
+
+        return user_id, access_token
+
+    def _set_pusher(self, user_id: str, access_token: str, enabled: bool) -> None:
+        """Creates or updates the pusher for the given user.
+
+        Args:
+            user_id: the user's Matrix ID.
+            access_token: the access token associated with the pusher.
+            enabled: whether to enable or disable the pusher.
+        """
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -749,11 +771,11 @@ class HTTPPusherTests(HomeserverTestCase):
                 pushkey="a@example.com",
                 lang=None,
                 data={"url": "http://example.com/_matrix/push/v1/notify"},
+                enabled=enabled,
+                device_id=user_tuple.device_id,
             )
         )
 
-        return user_id, access_token
-
     def test_dont_notify_rule_overrides_message(self) -> None:
         """
         The override push rule will suppress notification
@@ -791,3 +813,148 @@ class HTTPPusherTests(HomeserverTestCase):
         # The user sends a message back (sends a notification)
         self.helper.send(room, body="Hello", tok=access_token)
         self.assertEqual(len(self.push_attempts), 1)
+
+    @override_config({"experimental_features": {"msc3881_enabled": True}})
+    def test_disable(self) -> None:
+        """Tests that disabling a pusher means it's not pushed to anymore."""
+        user_id, access_token = self._make_user_with_pusher("user")
+        other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
+
+        room = self.helper.create_room_as(user_id, tok=access_token)
+        self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+        # Send a message and check that it generated a push.
+        self.helper.send(room, body="Hi!", tok=other_access_token)
+        self.assertEqual(len(self.push_attempts), 1)
+
+        # Disable the pusher.
+        self._set_pusher(user_id, access_token, enabled=False)
+
+        # Send another message and check that it did not generate a push.
+        self.helper.send(room, body="Hi!", tok=other_access_token)
+        self.assertEqual(len(self.push_attempts), 1)
+
+        # Get the pushers for the user and check that it is marked as disabled.
+        channel = self.make_request("GET", "/pushers", access_token=access_token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(len(channel.json_body["pushers"]), 1)
+
+        enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
+        self.assertFalse(enabled)
+        self.assertTrue(isinstance(enabled, bool))
+
+    @override_config({"experimental_features": {"msc3881_enabled": True}})
+    def test_enable(self) -> None:
+        """Tests that enabling a disabled pusher means it gets pushed to."""
+        # Create the user with the pusher already disabled.
+        user_id, access_token = self._make_user_with_pusher("user", enabled=False)
+        other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
+
+        room = self.helper.create_room_as(user_id, tok=access_token)
+        self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+        # Send a message and check that it did not generate a push.
+        self.helper.send(room, body="Hi!", tok=other_access_token)
+        self.assertEqual(len(self.push_attempts), 0)
+
+        # Enable the pusher.
+        self._set_pusher(user_id, access_token, enabled=True)
+
+        # Send another message and check that it did generate a push.
+        self.helper.send(room, body="Hi!", tok=other_access_token)
+        self.assertEqual(len(self.push_attempts), 1)
+
+        # Get the pushers for the user and check that it is marked as enabled.
+        channel = self.make_request("GET", "/pushers", access_token=access_token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(len(channel.json_body["pushers"]), 1)
+
+        enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
+        self.assertTrue(enabled)
+        self.assertTrue(isinstance(enabled, bool))
+
+    @override_config({"experimental_features": {"msc3881_enabled": True}})
+    def test_null_enabled(self) -> None:
+        """Tests that a pusher that has an 'enabled' column set to NULL (eg pushers
+        created before the column was introduced) is considered enabled.
+        """
+        # We intentionally set 'enabled' to None so that it's stored as NULL in the
+        # database.
+        user_id, access_token = self._make_user_with_pusher("user", enabled=None)  # type: ignore[arg-type]
+
+        channel = self.make_request("GET", "/pushers", access_token=access_token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(len(channel.json_body["pushers"]), 1)
+        self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"])
+
+    def test_update_different_device_access_token_device_id(self) -> None:
+        """Tests that if we create a pusher from one device, the update it from another
+        device, the access token and device ID associated with the pusher stays the
+        same.
+        """
+        # Create a user with a pusher.
+        user_id, access_token = self._make_user_with_pusher("user")
+
+        # Get the token ID for the current access token, since that's what we store in
+        # the pushers table. Also get the device ID from it.
+        user_tuple = self.get_success(
+            self.hs.get_datastores().main.get_user_by_access_token(access_token)
+        )
+        token_id = user_tuple.token_id
+        device_id = user_tuple.device_id
+
+        # Generate a new access token, and update the pusher with it.
+        new_token = self.login("user", "pass")
+        self._set_pusher(user_id, new_token, enabled=False)
+
+        # Get the current list of pushers for the user.
+        ret = self.get_success(
+            self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+        )
+        pushers: List[PusherConfig] = list(ret)
+
+        # Check that we still have one pusher, and that the access token and device ID
+        # associated with it didn't change.
+        self.assertEqual(len(pushers), 1)
+        self.assertEqual(pushers[0].access_token, token_id)
+        self.assertEqual(pushers[0].device_id, device_id)
+
+    @override_config({"experimental_features": {"msc3881_enabled": True}})
+    def test_device_id(self) -> None:
+        """Tests that a pusher created with a given device ID shows that device ID in
+        GET /pushers requests.
+        """
+        self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        # We create the pusher with an HTTP request rather than with
+        # _make_user_with_pusher so that we can test the device ID is correctly set when
+        # creating a pusher via an API call.
+        self.make_request(
+            method="POST",
+            path="/pushers/set",
+            content={
+                "kind": "http",
+                "app_id": "m.http",
+                "app_display_name": "HTTP Push Notifications",
+                "device_display_name": "pushy push",
+                "pushkey": "a@example.com",
+                "lang": "en",
+                "data": {"url": "http://example.com/_matrix/push/v1/notify"},
+            },
+            access_token=access_token,
+        )
+
+        # Look up the user info for the access token so we can compare the device ID.
+        lookup_result: TokenLookupResult = self.get_success(
+            self.hs.get_datastores().main.get_user_by_access_token(access_token)
+        )
+
+        # Get the user's devices and check it has the correct device ID.
+        channel = self.make_request("GET", "/pushers", access_token=access_token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(len(channel.json_body["pushers"]), 1)
+        self.assertEqual(
+            channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"],
+            lookup_result.device_id,
+        )
diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py
new file mode 100644
index 0000000000..b93cae67d3
--- /dev/null
+++ b/tests/replication/test_module_cache_invalidation.py
@@ -0,0 +1,79 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import synapse
+from synapse.module_api import cached
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+FIRST_VALUE = "one"
+SECOND_VALUE = "two"
+
+KEY = "mykey"
+
+
+class TestCache:
+    current_value = FIRST_VALUE
+
+    @cached()
+    async def cached_function(self, user_id: str) -> str:
+        return self.current_value
+
+
+class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+    ]
+
+    def test_module_cache_full_invalidation(self):
+        main_cache = TestCache()
+        self.hs.get_module_api().register_cached_function(main_cache.cached_function)
+
+        worker_hs = self.make_worker_hs("synapse.app.generic_worker")
+
+        worker_cache = TestCache()
+        worker_hs.get_module_api().register_cached_function(
+            worker_cache.cached_function
+        )
+
+        self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+        self.assertEqual(
+            FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
+        )
+
+        main_cache.current_value = SECOND_VALUE
+        worker_cache.current_value = SECOND_VALUE
+        # No invalidation yet, should return the cached value on both the main process and the worker
+        self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+        self.assertEqual(
+            FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
+        )
+
+        # Full invalidation on the main process, should be replicated on the worker that
+        # should returned the updated value too
+        self.get_success(
+            self.hs.get_module_api().invalidate_cache(
+                main_cache.cached_function, (KEY,)
+            )
+        )
+
+        self.assertEqual(
+            SECOND_VALUE, self.get_success(main_cache.cached_function(KEY))
+        )
+        self.assertEqual(
+            SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY))
+        )
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 8f4f6688ce..59fea93e49 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         token_id = user_dict.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 9f536ceeb3..1847e6ad6b 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2839,7 +2839,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=self.other_user,
                 access_token=token_id,
                 kind="http",
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
new file mode 100644
index 0000000000..d5bb16c98d
--- /dev/null
+++ b/tests/rest/client/test_login_token_request.py
@@ -0,0 +1,132 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest import admin
+from synapse.rest.client import login, login_token_request
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+from tests.unittest import override_config
+
+
+class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        login.register_servlets,
+        admin.register_servlets,
+        login_token_request.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        self.hs = self.setup_test_homeserver()
+        self.hs.config.registration.enable_registration = True
+        self.hs.config.registration.registrations_require_3pid = []
+        self.hs.config.registration.auto_join_rooms = []
+        self.hs.config.captcha.enable_registration_captcha = False
+
+        return self.hs
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.user = "user123"
+        self.password = "password"
+
+    def test_disabled(self) -> None:
+        channel = self.make_request("POST", "/login/token", {}, access_token=None)
+        self.assertEqual(channel.code, 400)
+
+        self.register_user(self.user, self.password)
+        token = self.login(self.user, self.password)
+
+        channel = self.make_request("POST", "/login/token", {}, access_token=token)
+        self.assertEqual(channel.code, 400)
+
+    @override_config({"experimental_features": {"msc3882_enabled": True}})
+    def test_require_auth(self) -> None:
+        channel = self.make_request("POST", "/login/token", {}, access_token=None)
+        self.assertEqual(channel.code, 401)
+
+    @override_config({"experimental_features": {"msc3882_enabled": True}})
+    def test_uia_on(self) -> None:
+        user_id = self.register_user(self.user, self.password)
+        token = self.login(self.user, self.password)
+
+        channel = self.make_request("POST", "/login/token", {}, access_token=token)
+        self.assertEqual(channel.code, 401)
+        self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+        session = channel.json_body["session"]
+
+        uia = {
+            "auth": {
+                "type": "m.login.password",
+                "identifier": {"type": "m.id.user", "user": self.user},
+                "password": self.password,
+                "session": session,
+            },
+        }
+
+        channel = self.make_request("POST", "/login/token", uia, access_token=token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["expires_in"], 300)
+
+        login_token = channel.json_body["login_token"]
+
+        channel = self.make_request(
+            "POST",
+            "/login",
+            content={"type": "m.login.token", "token": login_token},
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.json_body["user_id"], user_id)
+
+    @override_config(
+        {"experimental_features": {"msc3882_enabled": True, "msc3882_ui_auth": False}}
+    )
+    def test_uia_off(self) -> None:
+        user_id = self.register_user(self.user, self.password)
+        token = self.login(self.user, self.password)
+
+        channel = self.make_request("POST", "/login/token", {}, access_token=token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["expires_in"], 300)
+
+        login_token = channel.json_body["login_token"]
+
+        channel = self.make_request(
+            "POST",
+            "/login",
+            content={"type": "m.login.token", "token": login_token},
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.json_body["user_id"], user_id)
+
+    @override_config(
+        {
+            "experimental_features": {
+                "msc3882_enabled": True,
+                "msc3882_ui_auth": False,
+                "msc3882_token_timeout": "15s",
+            }
+        }
+    )
+    def test_expires_in(self) -> None:
+        self.register_user(self.user, self.password)
+        token = self.login(self.user, self.password)
+
+        channel = self.make_request("POST", "/login/token", {}, access_token=token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["expires_in"], 15)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 651f4f415d..d33e34d829 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -788,6 +788,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
             channel.json_body["chunk"][0],
         )
 
+    @unittest.override_config({"experimental_features": {"msc3715_enabled": True}})
     def test_repeated_paginate_relations(self) -> None:
         """Test that if we paginate using a limit and tokens then we get the
         expected events.
@@ -809,7 +810,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
 
             channel = self.make_request(
                 "GET",
-                f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+                f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=3{from_token}",
                 access_token=self.user_token,
             )
             self.assertEqual(200, channel.code, channel.json_body)
@@ -827,6 +828,32 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
         found_event_ids.reverse()
         self.assertEqual(found_event_ids, expected_event_ids)
 
+        # Test forward pagination.
+        prev_token = ""
+        found_event_ids = []
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + prev_token
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?org.matrix.msc3715.dir=f&limit=3{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEqual(200, channel.code, channel.json_body)
+
+            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+            next_batch = channel.json_body.get("next_batch")
+
+            self.assertNotEqual(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        self.assertEqual(found_event_ids, expected_event_ids)
+
     def test_pagination_from_sync_and_messages(self) -> None:
         """Pagination tokens from /sync and /messages can be used to paginate /relations."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 158ad1f439..32a798d74b 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -103,6 +103,11 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
             self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
 
     def test_persisting_event_invalidates_cache(self):
+        """
+        Test to make sure that the `have_seen_event` cache
+        is invalidated after we persist an event and returns
+        the updated value.
+        """
         event, event_context = self.get_success(
             create_event(
                 self.hs,
@@ -145,6 +150,33 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
             # That should result in a single db query to lookup
             self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
 
+    def test_invalidate_cache_by_room_id(self):
+        """
+        Test to make sure that all events associated with the given `(room_id,)`
+        are invalidated in the `have_seen_event` cache.
+        """
+        with LoggingContext(name="test") as ctx:
+            # Prime the cache with some values
+            res = self.get_success(
+                self.store.have_seen_events(self.room_id, self.event_ids)
+            )
+            self.assertEqual(res, set(self.event_ids))
+
+            # That should result in a single db query to lookup
+            self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
+        # Clear the cache with any events associated with the `room_id`
+        self.store.have_seen_event.invalidate((self.room_id,))
+
+        with LoggingContext(name="test") as ctx:
+            res = self.get_success(
+                self.store.have_seen_events(self.room_id, self.event_ids)
+            )
+            self.assertEqual(res, set(self.event_ids))
+
+            # Since we cleared the cache, it should result in another db query to lookup
+            self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
 
 class EventCacheTestCase(unittest.HomeserverTestCase):
     """Test that the various layers of event cache works."""
diff --git a/tests/unittest.py b/tests/unittest.py
index 975b0a23a7..00cb023198 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -300,47 +300,31 @@ class HomeserverTestCase(TestCase):
         if hasattr(self, "user_id"):
             if self.hijack_auth:
                 assert self.helper.auth_user_id is not None
+                token = "some_fake_token"
 
                 # We need a valid token ID to satisfy foreign key constraints.
                 token_id = self.get_success(
                     self.hs.get_datastores().main.add_access_token_to_user(
                         self.helper.auth_user_id,
-                        "some_fake_token",
+                        token,
                         None,
                         None,
                     )
                 )
 
-                async def get_user_by_access_token(
-                    token: Optional[str] = None, allow_guest: bool = False
-                ) -> JsonDict:
-                    assert self.helper.auth_user_id is not None
-                    return {
-                        "user": UserID.from_string(self.helper.auth_user_id),
-                        "token_id": token_id,
-                        "is_guest": False,
-                    }
-
-                async def get_user_by_req(
-                    request: SynapseRequest,
-                    allow_guest: bool = False,
-                    allow_expired: bool = False,
-                ) -> Requester:
+                # This has to be a function and not just a Mock, because
+                # `self.helper.auth_user_id` is temporarily reassigned in some tests
+                async def get_requester(*args, **kwargs) -> Requester:
                     assert self.helper.auth_user_id is not None
                     return create_requester(
-                        UserID.from_string(self.helper.auth_user_id),
-                        token_id,
-                        False,
-                        False,
-                        None,
+                        user_id=UserID.from_string(self.helper.auth_user_id),
+                        access_token_id=token_id,
                     )
 
                 # Type ignore: mypy doesn't like us assigning to methods.
-                self.hs.get_auth().get_user_by_req = get_user_by_req  # type: ignore[assignment]
-                self.hs.get_auth().get_user_by_access_token = get_user_by_access_token  # type: ignore[assignment]
-                self.hs.get_auth().get_access_token_from_request = Mock(  # type: ignore[assignment]
-                    return_value="1234"
-                )
+                self.hs.get_auth().get_user_by_req = get_requester  # type: ignore[assignment]
+                self.hs.get_auth().get_user_by_access_token = get_requester  # type: ignore[assignment]
+                self.hs.get_auth().get_access_token_from_request = Mock(return_value=token)  # type: ignore[assignment]
 
         if self.needs_threadpool:
             self.reactor.threadpool = ThreadPool()  # type: ignore[assignment]
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 48e616ac74..90861fe522 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Set
+from typing import Iterable, Set, Tuple
 from unittest import mock
 
 from twisted.internet import defer, reactor
@@ -1008,3 +1008,34 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             obj.inner_context_was_finished, "Tried to restart a finished logcontext"
         )
         self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+    def test_num_args_mismatch(self):
+        """
+        Make sure someone does not accidentally use @cachedList on a method with
+        a mismatch in the number args to the underlying single cache method.
+        """
+
+        class Cls:
+            @descriptors.cached(tree=True)
+            def fn(self, room_id, event_id):
+                pass
+
+            # This is wrong ❌. `@cachedList` expects to be given the same number
+            # of arguments as the underlying cached function, just with one of
+            # the arguments being an iterable
+            @descriptors.cachedList(cached_method_name="fn", list_name="keys")
+            def list_fn(self, keys: Iterable[Tuple[str, str]]):
+                pass
+
+            # Corrected syntax ✅
+            #
+            # @cachedList(cached_method_name="fn", list_name="event_ids")
+            # async def list_fn(
+            #     self, room_id: str, event_ids: Collection[str],
+            # )
+
+        obj = Cls()
+
+        # Make sure this raises an error about the arg mismatch
+        with self.assertRaises(Exception):
+            obj.list_fn([("foo", "bar")])