diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index f462a8b1c7..a8639d8f82 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -329,3 +329,110 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([local_message_event_id, remote_event_2.event_id])
+
+
+class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.state = self.hs.get_state_handler()
+ self.persistence = self.hs.get_storage().persistence
+ self.store = self.hs.get_datastore()
+
+ def test_remote_user_rooms_cache_invalidated(self):
+ """Test that if the server leaves a room the `get_rooms_for_user` cache
+ is invalidated for remote users.
+ """
+
+ # Set up a room with a local and remote user in it.
+ user_id = self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ room_id = self.helper.create_room_as(
+ "user", room_version=RoomVersions.V6.identifier, tok=token
+ )
+
+ body = self.helper.send(room_id, body="Test", tok=token)
+ local_message_event_id = body["event_id"]
+
+ # Fudge a join event for a remote user.
+ remote_user = "@user:other"
+ remote_event_1 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": remote_user,
+ "content": {"membership": Membership.JOIN},
+ "room_id": room_id,
+ "sender": remote_user,
+ "depth": 5,
+ "prev_events": [local_message_event_id],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ context = self.get_success(self.state.compute_event_context(remote_event_1))
+ self.get_success(self.persistence.persist_event(remote_event_1, context))
+
+ # Call `get_rooms_for_user` to add the remote user to the cache
+ rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
+ self.assertEqual(set(rooms), {room_id})
+
+ # Now we have the local server leave the room, and check that calling
+ # `get_user_in_room` for the remote user no longer includes the room.
+ self.helper.leave(room_id, user_id, tok=token)
+
+ rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
+ self.assertEqual(set(rooms), set())
+
+ def test_room_remote_user_cache_invalidated(self):
+ """Test that if the server leaves a room the `get_users_in_room` cache
+ is invalidated for remote users.
+ """
+
+ # Set up a room with a local and remote user in it.
+ user_id = self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ room_id = self.helper.create_room_as(
+ "user", room_version=RoomVersions.V6.identifier, tok=token
+ )
+
+ body = self.helper.send(room_id, body="Test", tok=token)
+ local_message_event_id = body["event_id"]
+
+ # Fudge a join event for a remote user.
+ remote_user = "@user:other"
+ remote_event_1 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": remote_user,
+ "content": {"membership": Membership.JOIN},
+ "room_id": room_id,
+ "sender": remote_user,
+ "depth": 5,
+ "prev_events": [local_message_event_id],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ context = self.get_success(self.state.compute_event_context(remote_event_1))
+ self.get_success(self.persistence.persist_event(remote_event_1, context))
+
+ # Call `get_users_in_room` to add the remote user to the cache
+ users = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual(set(users), {user_id, remote_user})
+
+ # Now we have the local server leave the room, and check that calling
+ # `get_user_in_room` for the remote user no longer includes the room.
+ self.helper.leave(room_id, user_id, tok=token)
+
+ users = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual(users, [])
|