summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11999.bugfix1
-rw-r--r--synapse/storage/databases/main/events.py27
-rw-r--r--tests/storage/test_events.py107
3 files changed, 124 insertions, 11 deletions
diff --git a/changelog.d/11999.bugfix b/changelog.d/11999.bugfix
new file mode 100644
index 0000000000..fd84095900
--- /dev/null
+++ b/changelog.d/11999.bugfix
@@ -0,0 +1 @@
+Fix long standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5246fccad5..a1d7a9b413 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -975,6 +975,17 @@ class PersistEventsStore:
             to_delete = delta_state.to_delete
             to_insert = delta_state.to_insert
 
+            # Figure out the changes of membership to invalidate the
+            # `get_rooms_for_user` cache.
+            # We find out which membership events we may have deleted
+            # and which we have added, then we invalidate the caches for all
+            # those users.
+            members_changed = {
+                state_key
+                for ev_type, state_key in itertools.chain(to_delete, to_insert)
+                if ev_type == EventTypes.Member
+            }
+
             if delta_state.no_longer_in_room:
                 # Server is no longer in the room so we delete the room from
                 # current_state_events, being careful we've already updated the
@@ -993,6 +1004,11 @@ class PersistEventsStore:
                 """
                 txn.execute(sql, (stream_id, self._instance_name, room_id))
 
+                # We also want to invalidate the membership caches for users
+                # that were in the room.
+                users_in_room = self.store.get_users_in_room_txn(txn, room_id)
+                members_changed.update(users_in_room)
+
                 self.db_pool.simple_delete_txn(
                     txn,
                     table="current_state_events",
@@ -1102,17 +1118,6 @@ class PersistEventsStore:
 
             # Invalidate the various caches
 
-            # Figure out the changes of membership to invalidate the
-            # `get_rooms_for_user` cache.
-            # We find out which membership events we may have deleted
-            # and which we have added, then we invalidate the caches for all
-            # those users.
-            members_changed = {
-                state_key
-                for ev_type, state_key in itertools.chain(to_delete, to_insert)
-                if ev_type == EventTypes.Member
-            }
-
             for member in members_changed:
                 txn.call_after(
                     self.store.get_rooms_for_user_with_stream_ordering.invalidate,
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index f462a8b1c7..a8639d8f82 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -329,3 +329,110 @@ class ExtremPruneTestCase(HomeserverTestCase):
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([local_message_event_id, remote_event_2.event_id])
+
+
+class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, homeserver):
+        self.state = self.hs.get_state_handler()
+        self.persistence = self.hs.get_storage().persistence
+        self.store = self.hs.get_datastore()
+
+    def test_remote_user_rooms_cache_invalidated(self):
+        """Test that if the server leaves a room the `get_rooms_for_user` cache
+        is invalidated for remote users.
+        """
+
+        # Set up a room with a local and remote user in it.
+        user_id = self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        room_id = self.helper.create_room_as(
+            "user", room_version=RoomVersions.V6.identifier, tok=token
+        )
+
+        body = self.helper.send(room_id, body="Test", tok=token)
+        local_message_event_id = body["event_id"]
+
+        # Fudge a join event for a remote user.
+        remote_user = "@user:other"
+        remote_event_1 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": remote_user,
+                "content": {"membership": Membership.JOIN},
+                "room_id": room_id,
+                "sender": remote_user,
+                "depth": 5,
+                "prev_events": [local_message_event_id],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        context = self.get_success(self.state.compute_event_context(remote_event_1))
+        self.get_success(self.persistence.persist_event(remote_event_1, context))
+
+        # Call `get_rooms_for_user` to add the remote user to the cache
+        rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
+        self.assertEqual(set(rooms), {room_id})
+
+        # Now we have the local server leave the room, and check that calling
+        # `get_user_in_room` for the remote user no longer includes the room.
+        self.helper.leave(room_id, user_id, tok=token)
+
+        rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
+        self.assertEqual(set(rooms), set())
+
+    def test_room_remote_user_cache_invalidated(self):
+        """Test that if the server leaves a room the `get_users_in_room` cache
+        is invalidated for remote users.
+        """
+
+        # Set up a room with a local and remote user in it.
+        user_id = self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        room_id = self.helper.create_room_as(
+            "user", room_version=RoomVersions.V6.identifier, tok=token
+        )
+
+        body = self.helper.send(room_id, body="Test", tok=token)
+        local_message_event_id = body["event_id"]
+
+        # Fudge a join event for a remote user.
+        remote_user = "@user:other"
+        remote_event_1 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": remote_user,
+                "content": {"membership": Membership.JOIN},
+                "room_id": room_id,
+                "sender": remote_user,
+                "depth": 5,
+                "prev_events": [local_message_event_id],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        context = self.get_success(self.state.compute_event_context(remote_event_1))
+        self.get_success(self.persistence.persist_event(remote_event_1, context))
+
+        # Call `get_users_in_room` to add the remote user to the cache
+        users = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertEqual(set(users), {user_id, remote_user})
+
+        # Now we have the local server leave the room, and check that calling
+        # `get_user_in_room` for the remote user no longer includes the room.
+        self.helper.leave(room_id, user_id, tok=token)
+
+        users = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertEqual(users, [])