summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorSean Quah <8349537+squahtx@users.noreply.github.com>2022-08-18 11:53:02 +0100
committerGitHub <noreply@github.com>2022-08-18 11:53:02 +0100
commit84169a82dcf7dfb6eb7d307ea7f5e33cb57f6e3f (patch)
tree4aeefdf294bf62977997040c06b0a6889b38b05c /synapse/handlers
parentAdd metrics to track how the rate limiter is affecting requests (sleep/reject... (diff)
downloadsynapse-84169a82dcf7dfb6eb7d307ea7f5e33cb57f6e3f.tar.xz
Avoid blocking lazy-loading `/sync`s during partial joins (#13477)
Use a state filter or accept partial state in a few places where we
request state, to avoid blocking.

To make lazy-loading `/sync`s work, we need to provide the memberships
of event senders, which are not guaranteed to be in the room state.
Instead we dig through auth events for memberships to present to
clients. The auth events of an event are guaranteed to contain a
passable membership event, otherwise the event would have been rejected.

Note that this only covers the common code paths encountered during
testing. There has been no exhaustive checking of all sync code paths.

Fixes #13146.

Signed-off-by: Sean Quah <seanq@matrix.org>
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/sync.py253
1 files changed, 223 insertions, 30 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 3ca01391c9..b4d3f3958c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -16,9 +16,11 @@ import logging
 from typing import (
     TYPE_CHECKING,
     Any,
+    Collection,
     Dict,
     FrozenSet,
     List,
+    Mapping,
     Optional,
     Sequence,
     Set,
@@ -517,10 +519,17 @@ class SyncHandler:
                 # ensure that we always include current state in the timeline
                 current_state_ids: FrozenSet[str] = frozenset()
                 if any(e.is_state() for e in recents):
+                    # FIXME(faster_joins): We use the partial state here as
+                    # we don't want to block `/sync` on finishing a lazy join.
+                    # Which should be fine once
+                    # https://github.com/matrix-org/synapse/issues/12989 is resolved,
+                    # since we shouldn't reach here anymore?
+                    # Note that we use the current state as a whitelist for filtering
+                    # `recents`, so partial state is only a problem when a membership
+                    # event turns up in `recents` but has not made it into the current
+                    # state.
                     current_state_ids_map = (
-                        await self._state_storage_controller.get_current_state_ids(
-                            room_id
-                        )
+                        await self.store.get_partial_current_state_ids(room_id)
                     )
                     current_state_ids = frozenset(current_state_ids_map.values())
 
@@ -589,7 +598,13 @@ class SyncHandler:
                 if any(e.is_state() for e in loaded_recents):
                     # FIXME(faster_joins): We use the partial state here as
                     # we don't want to block `/sync` on finishing a lazy join.
-                    # Is this the correct way of doing it?
+                    # Which should be fine once
+                    # https://github.com/matrix-org/synapse/issues/12989 is resolved,
+                    # since we shouldn't reach here anymore?
+                    # Note that we use the current state as a whitelist for filtering
+                    # `loaded_recents`, so partial state is only a problem when a
+                    # membership event turns up in `loaded_recents` but has not made it
+                    # into the current state.
                     current_state_ids_map = (
                         await self.store.get_partial_current_state_ids(room_id)
                     )
@@ -637,7 +652,10 @@ class SyncHandler:
         )
 
     async def get_state_after_event(
-        self, event_id: str, state_filter: Optional[StateFilter] = None
+        self,
+        event_id: str,
+        state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """
         Get the room state after the given event
@@ -645,9 +663,14 @@ class SyncHandler:
         Args:
             event_id: event of interest
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at the event and `state_filter` is not satisfied by partial state.
+                Defaults to `True`.
         """
         state_ids = await self._state_storage_controller.get_state_ids_for_event(
-            event_id, state_filter=state_filter or StateFilter.all()
+            event_id,
+            state_filter=state_filter or StateFilter.all(),
+            await_full_state=await_full_state,
         )
 
         # using get_metadata_for_events here (instead of get_event) sidesteps an issue
@@ -670,6 +693,7 @@ class SyncHandler:
         room_id: str,
         stream_position: StreamToken,
         state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """Get the room state at a particular stream position
 
@@ -677,6 +701,9 @@ class SyncHandler:
             room_id: room for which to get state
             stream_position: point at which to get state
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at the last event in the room before `stream_position` and
+                `state_filter` is not satisfied by partial state. Defaults to `True`.
         """
         # FIXME: This gets the state at the latest event before the stream ordering,
         # which might not be the same as the "current state" of the room at the time
@@ -688,7 +715,9 @@ class SyncHandler:
 
         if last_event_id:
             state = await self.get_state_after_event(
-                last_event_id, state_filter=state_filter or StateFilter.all()
+                last_event_id,
+                state_filter=state_filter or StateFilter.all(),
+                await_full_state=await_full_state,
             )
 
         else:
@@ -891,7 +920,15 @@ class SyncHandler:
         with Measure(self.clock, "compute_state_delta"):
             # The memberships needed for events in the timeline.
             # Only calculated when `lazy_load_members` is on.
-            members_to_fetch = None
+            members_to_fetch: Optional[Set[str]] = None
+
+            # A dictionary mapping user IDs to the first event in the timeline sent by
+            # them. Only calculated when `lazy_load_members` is on.
+            first_event_by_sender_map: Optional[Dict[str, EventBase]] = None
+
+            # The contribution to the room state from state events in the timeline.
+            # Only contains the last event for any given state key.
+            timeline_state: StateMap[str]
 
             lazy_load_members = sync_config.filter_collection.lazy_load_members()
             include_redundant_members = (
@@ -902,10 +939,23 @@ class SyncHandler:
                 # We only request state for the members needed to display the
                 # timeline:
 
-                members_to_fetch = {
-                    event.sender  # FIXME: we also care about invite targets etc.
-                    for event in batch.events
-                }
+                timeline_state = {}
+
+                members_to_fetch = set()
+                first_event_by_sender_map = {}
+                for event in batch.events:
+                    # Build the map from user IDs to the first timeline event they sent.
+                    if event.sender not in first_event_by_sender_map:
+                        first_event_by_sender_map[event.sender] = event
+
+                    # We need the event's sender, unless their membership was in a
+                    # previous timeline event.
+                    if (EventTypes.Member, event.sender) not in timeline_state:
+                        members_to_fetch.add(event.sender)
+                    # FIXME: we also care about invite targets etc.
+
+                    if event.is_state():
+                        timeline_state[(event.type, event.state_key)] = event.event_id
 
                 if full_state:
                     # always make sure we LL ourselves so we know we're in the room
@@ -915,16 +965,21 @@ class SyncHandler:
                     members_to_fetch.add(sync_config.user.to_string())
 
                 state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch)
+
+                # We are happy to use partial state to compute the `/sync` response.
+                # Since partial state may not include the lazy-loaded memberships we
+                # require, we fix up the state response afterwards with memberships from
+                # auth events.
+                await_full_state = False
             else:
-                state_filter = StateFilter.all()
+                timeline_state = {
+                    (event.type, event.state_key): event.event_id
+                    for event in batch.events
+                    if event.is_state()
+                }
 
-            # The contribution to the room state from state events in the timeline.
-            # Only contains the last event for any given state key.
-            timeline_state = {
-                (event.type, event.state_key): event.event_id
-                for event in batch.events
-                if event.is_state()
-            }
+                state_filter = StateFilter.all()
+                await_full_state = True
 
             # Now calculate the state to return in the sync response for the room.
             # This is more or less the change in state between the end of the previous
@@ -936,19 +991,26 @@ class SyncHandler:
                 if batch:
                     state_at_timeline_end = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[-1].event_id, state_filter=state_filter
+                            batch.events[-1].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
 
                     state_at_timeline_start = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[0].event_id, state_filter=state_filter
+                            batch.events[0].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
 
                 else:
                     state_at_timeline_end = await self.get_state_at(
-                        room_id, stream_position=now_token, state_filter=state_filter
+                        room_id,
+                        stream_position=now_token,
+                        state_filter=state_filter,
+                        await_full_state=await_full_state,
                     )
 
                     state_at_timeline_start = state_at_timeline_end
@@ -964,14 +1026,19 @@ class SyncHandler:
                 if batch:
                     state_at_timeline_start = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[0].event_id, state_filter=state_filter
+                            batch.events[0].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
                 else:
                     # We can get here if the user has ignored the senders of all
                     # the recent events.
                     state_at_timeline_start = await self.get_state_at(
-                        room_id, stream_position=now_token, state_filter=state_filter
+                        room_id,
+                        stream_position=now_token,
+                        state_filter=state_filter,
+                        await_full_state=await_full_state,
                     )
 
                 # for now, we disable LL for gappy syncs - see
@@ -993,20 +1060,28 @@ class SyncHandler:
                 # is indeed the case.
                 assert since_token is not None
                 state_at_previous_sync = await self.get_state_at(
-                    room_id, stream_position=since_token, state_filter=state_filter
+                    room_id,
+                    stream_position=since_token,
+                    state_filter=state_filter,
+                    await_full_state=await_full_state,
                 )
 
                 if batch:
                     state_at_timeline_end = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[-1].event_id, state_filter=state_filter
+                            batch.events[-1].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
                 else:
                     # We can get here if the user has ignored the senders of all
                     # the recent events.
                     state_at_timeline_end = await self.get_state_at(
-                        room_id, stream_position=now_token, state_filter=state_filter
+                        room_id,
+                        stream_position=now_token,
+                        state_filter=state_filter,
+                        await_full_state=await_full_state,
                     )
 
                 state_ids = _calculate_state(
@@ -1036,8 +1111,23 @@ class SyncHandler:
                                 (EventTypes.Member, member)
                                 for member in members_to_fetch
                             ),
+                            await_full_state=False,
                         )
 
+            # If we only have partial state for the room, `state_ids` may be missing the
+            # memberships we wanted. We attempt to find some by digging through the auth
+            # events of timeline events.
+            if lazy_load_members and await self.store.is_partial_state_room(room_id):
+                assert members_to_fetch is not None
+                assert first_event_by_sender_map is not None
+
+                additional_state_ids = (
+                    await self._find_missing_partial_state_memberships(
+                        room_id, members_to_fetch, first_event_by_sender_map, state_ids
+                    )
+                )
+                state_ids = {**state_ids, **additional_state_ids}
+
             # At this point, if `lazy_load_members` is enabled, `state_ids` includes
             # the memberships of all event senders in the timeline. This is because we
             # may not have sent the memberships in a previous sync.
@@ -1086,6 +1176,99 @@ class SyncHandler:
             if e.type != EventTypes.Aliases  # until MSC2261 or alternative solution
         }
 
+    async def _find_missing_partial_state_memberships(
+        self,
+        room_id: str,
+        members_to_fetch: Collection[str],
+        events_with_membership_auth: Mapping[str, EventBase],
+        found_state_ids: StateMap[str],
+    ) -> StateMap[str]:
+        """Finds missing memberships from a set of auth events and returns them as a
+        state map.
+
+        Args:
+            room_id: The partial state room to find the remaining memberships for.
+            members_to_fetch: The memberships to find.
+            events_with_membership_auth: A mapping from user IDs to events whose auth
+                events are known to contain their membership.
+            found_state_ids: A dict from (type, state_key) -> state_event_id, containing
+                memberships that have been previously found. Entries in
+                `members_to_fetch` that have a membership in `found_state_ids` are
+                ignored.
+
+        Returns:
+            A dict from ("m.room.member", state_key) -> state_event_id, containing the
+            memberships missing from `found_state_ids`.
+
+        Raises:
+            KeyError: if `events_with_membership_auth` does not have an entry for a
+                missing membership. Memberships in `found_state_ids` do not need an
+                entry in `events_with_membership_auth`.
+        """
+        additional_state_ids: MutableStateMap[str] = {}
+
+        # Tracks the missing members for logging purposes.
+        missing_members = set()
+
+        # Identify memberships missing from `found_state_ids` and pick out the auth
+        # events in which to look for them.
+        auth_event_ids: Set[str] = set()
+        for member in members_to_fetch:
+            if (EventTypes.Member, member) in found_state_ids:
+                continue
+
+            missing_members.add(member)
+            event_with_membership_auth = events_with_membership_auth[member]
+            auth_event_ids.update(event_with_membership_auth.auth_event_ids())
+
+        auth_events = await self.store.get_events(auth_event_ids)
+
+        # Run through the missing memberships once more, picking out the memberships
+        # from the pile of auth events we have just fetched.
+        for member in members_to_fetch:
+            if (EventTypes.Member, member) in found_state_ids:
+                continue
+
+            event_with_membership_auth = events_with_membership_auth[member]
+
+            # Dig through the auth events to find the desired membership.
+            for auth_event_id in event_with_membership_auth.auth_event_ids():
+                # We only store events once we have all their auth events,
+                # so the auth event must be in the pile we have just
+                # fetched.
+                auth_event = auth_events[auth_event_id]
+
+                if (
+                    auth_event.type == EventTypes.Member
+                    and auth_event.state_key == member
+                ):
+                    missing_members.remove(member)
+                    additional_state_ids[
+                        (EventTypes.Member, member)
+                    ] = auth_event.event_id
+                    break
+
+        if missing_members:
+            # There really shouldn't be any missing memberships now. Either:
+            #  * we couldn't find an auth event, which shouldn't happen because we do
+            #    not persist events with persisting their auth events first, or
+            #  * the set of auth events did not contain a membership we wanted, which
+            #    means our caller didn't compute the events in `members_to_fetch`
+            #    correctly, or we somehow accepted an event whose auth events were
+            #    dodgy.
+            logger.error(
+                "Failed to find memberships for %s in partial state room "
+                "%s in the auth events of %s.",
+                missing_members,
+                room_id,
+                [
+                    events_with_membership_auth[member].event_id
+                    for member in missing_members
+                ],
+            )
+
+        return additional_state_ids
+
     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
     ) -> NotifCounts:
@@ -1730,7 +1913,11 @@ class SyncHandler:
                 continue
 
             if room_id in sync_result_builder.joined_room_ids or has_join:
-                old_state_ids = await self.get_state_at(room_id, since_token)
+                old_state_ids = await self.get_state_at(
+                    room_id,
+                    since_token,
+                    state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]),
+                )
                 old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
                 old_mem_ev = None
                 if old_mem_ev_id:
@@ -1756,7 +1943,13 @@ class SyncHandler:
                     newly_left_rooms.append(room_id)
                 else:
                     if not old_state_ids:
-                        old_state_ids = await self.get_state_at(room_id, since_token)
+                        old_state_ids = await self.get_state_at(
+                            room_id,
+                            since_token,
+                            state_filter=StateFilter.from_types(
+                                [(EventTypes.Member, user_id)]
+                            ),
+                        )
                         old_mem_ev_id = old_state_ids.get(
                             (EventTypes.Member, user_id), None
                         )