summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/6801.bugfix1
-rw-r--r--synapse/storage/data_stores/main/roommember.py37
-rw-r--r--synapse/storage/persist_events.py51
3 files changed, 83 insertions, 6 deletions
diff --git a/changelog.d/6801.bugfix b/changelog.d/6801.bugfix
new file mode 100644
index 0000000000..f401fa5d69
--- /dev/null
+++ b/changelog.d/6801.bugfix
@@ -0,0 +1 @@
+Fix bug where Synapse didn't invalidate cache of remote users' devices when Synapse left a room.
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 9acef7c950..042289f0e0 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import Iterable, List
+from typing import Iterable, List, Set
 
 from six import iteritems, itervalues
 
@@ -40,7 +40,7 @@ from synapse.storage.roommember import (
     ProfileInfo,
     RoomsForUser,
 )
-from synapse.types import get_domain_from_id
+from synapse.types import Collection, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
@@ -439,6 +439,39 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return results
 
+    async def get_users_server_still_shares_room_with(
+        self, user_ids: Collection[str]
+    ) -> Set[str]:
+        """Given a list of users return the set that the server still share a
+        room with.
+        """
+
+        if not user_ids:
+            return set()
+
+        def _get_users_server_still_shares_room_with_txn(txn):
+            sql = """
+                SELECT state_key FROM current_state_events
+                WHERE
+                    type = 'm.room.member'
+                    AND membership = 'join'
+                    AND %s
+                GROUP BY state_key
+            """
+
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "state_key", user_ids
+            )
+
+            txn.execute(sql % (clause,), args)
+
+            return set(row[0] for row in txn)
+
+        return await self.db.runInteraction(
+            "get_users_server_still_shares_room_with",
+            _get_users_server_still_shares_room_with_txn,
+        )
+
     @defer.inlineCallbacks
     def get_rooms_for_user(self, user_id, on_invalidate=None):
         """Returns a set of room_ids the user is currently joined to.
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index d060c8b992..86166fd4c1 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -18,7 +18,7 @@
 import itertools
 import logging
 from collections import deque, namedtuple
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 from six import iteritems
 from six.moves import range
@@ -318,6 +318,11 @@ class EventsPersistenceStorage(object):
             # room
             state_delta_for_room = {}
 
+            # Set of remote users which were in rooms the server has left. We
+            # should check if we still share any rooms and if not we mark their
+            # device lists as stale.
+            potentially_left_users = set()  # type: Set[str]
+
             if not backfilled:
                 with Measure(self._clock, "_calculate_state_and_extrem"):
                     # Work out the new "current state" for each room.
@@ -421,7 +426,11 @@ class EventsPersistenceStorage(object):
                             # the room then we delete the current state and
                             # extremities.
                             is_still_joined = await self._is_server_still_joined(
-                                room_id, ev_ctx_rm, delta, current_state
+                                room_id,
+                                ev_ctx_rm,
+                                delta,
+                                current_state,
+                                potentially_left_users,
                             )
                             if not is_still_joined:
                                 logger.info("Server no longer in room %s", room_id)
@@ -444,6 +453,8 @@ class EventsPersistenceStorage(object):
                 backfilled=backfilled,
             )
 
+            await self._handle_potentially_left_users(potentially_left_users)
+
     async def _calculate_new_extremities(
         self,
         room_id: str,
@@ -688,6 +699,7 @@ class EventsPersistenceStorage(object):
         ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]],
         delta: DeltaState,
         current_state: Optional[StateMap[str]],
+        potentially_left_users: Set[str],
     ) -> bool:
         """Check if the server will still be joined after the given events have
         been persised.
@@ -699,6 +711,9 @@ class EventsPersistenceStorage(object):
                 and what the new current state will be.
             current_state: The new current state if it already been calculated,
                 otherwise None.
+            potentially_left_users: If the server has left the room, then joined
+                remote users will be added to this set to indicate that the
+                server may no longer be sharing a room with them.
         """
 
         if not any(
@@ -741,5 +756,33 @@ class EventsPersistenceStorage(object):
         is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
         if is_still_joined:
             return True
-        else:
-            return False
+
+        # The server will leave the room, so we go and find out which remote
+        # users will still be joined when we leave.
+        remote_event_ids = [
+            event_id
+            for (typ, state_key,), event_id in current_state.items()
+            if typ == EventTypes.Member and not self.is_mine_id(state_key)
+        ]
+        rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
+        potentially_left_users.update(
+            row["user_id"] for row in rows if row["membership"] == Membership.JOIN
+        )
+
+        return False
+
+    async def _handle_potentially_left_users(self, user_ids: Set[str]):
+        """Given a set of remote users check if the server still shares a room with
+        them. If not then mark those users' device cache as stale.
+        """
+
+        if not user_ids:
+            return
+
+        joined_users = await self.main_store.get_users_server_still_shares_room_with(
+            user_ids
+        )
+        left_users = user_ids - joined_users
+
+        for user_id in left_users:
+            await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)