summary refs log tree commit diff
path: root/synapse/handlers/sync.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/sync.py')
-rw-r--r--synapse/handlers/sync.py361
1 files changed, 297 insertions, 64 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index cddfb4cec7..2991967226 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -13,7 +13,19 @@
 # limitations under the License.
 import itertools
 import logging
-from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    FrozenSet,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+)
 
 import attr
 from prometheus_client import Counter
@@ -94,7 +106,7 @@ class SyncConfig:
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class TimelineBatch:
     prev_batch: StreamToken
-    events: List[EventBase]
+    events: Sequence[EventBase]
     limited: bool
     # A mapping of event ID to the bundled aggregations for the above events.
     # This is only calculated if limited is true.
@@ -512,10 +524,17 @@ class SyncHandler:
                 # ensure that we always include current state in the timeline
                 current_state_ids: FrozenSet[str] = frozenset()
                 if any(e.is_state() for e in recents):
+                    # FIXME(faster_joins): We use the partial state here as
+                    # we don't want to block `/sync` on finishing a lazy join.
+                    # Which should be fine once
+                    # https://github.com/matrix-org/synapse/issues/12989 is resolved,
+                    # since we shouldn't reach here anymore?
+                    # Note that we use the current state as a whitelist for filtering
+                    # `recents`, so partial state is only a problem when a membership
+                    # event turns up in `recents` but has not made it into the current
+                    # state.
                     current_state_ids_map = (
-                        await self._state_storage_controller.get_current_state_ids(
-                            room_id
-                        )
+                        await self.store.get_partial_current_state_ids(room_id)
                     )
                     current_state_ids = frozenset(current_state_ids_map.values())
 
@@ -584,7 +603,13 @@ class SyncHandler:
                 if any(e.is_state() for e in loaded_recents):
                     # FIXME(faster_joins): We use the partial state here as
                     # we don't want to block `/sync` on finishing a lazy join.
-                    # Is this the correct way of doing it?
+                    # Which should be fine once
+                    # https://github.com/matrix-org/synapse/issues/12989 is resolved,
+                    # since we shouldn't reach here anymore?
+                    # Note that we use the current state as a whitelist for filtering
+                    # `loaded_recents`, so partial state is only a problem when a
+                    # membership event turns up in `loaded_recents` but has not made it
+                    # into the current state.
                     current_state_ids_map = (
                         await self.store.get_partial_current_state_ids(room_id)
                     )
@@ -632,7 +657,10 @@ class SyncHandler:
         )
 
     async def get_state_after_event(
-        self, event_id: str, state_filter: Optional[StateFilter] = None
+        self,
+        event_id: str,
+        state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """
         Get the room state after the given event
@@ -640,9 +668,14 @@ class SyncHandler:
         Args:
             event_id: event of interest
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at the event and `state_filter` is not satisfied by partial state.
+                Defaults to `True`.
         """
         state_ids = await self._state_storage_controller.get_state_ids_for_event(
-            event_id, state_filter=state_filter or StateFilter.all()
+            event_id,
+            state_filter=state_filter or StateFilter.all(),
+            await_full_state=await_full_state,
         )
 
         # using get_metadata_for_events here (instead of get_event) sidesteps an issue
@@ -665,6 +698,7 @@ class SyncHandler:
         room_id: str,
         stream_position: StreamToken,
         state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """Get the room state at a particular stream position
 
@@ -672,6 +706,9 @@ class SyncHandler:
             room_id: room for which to get state
             stream_position: point at which to get state
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at the last event in the room before `stream_position` and
+                `state_filter` is not satisfied by partial state. Defaults to `True`.
         """
         # FIXME: This gets the state at the latest event before the stream ordering,
         # which might not be the same as the "current state" of the room at the time
@@ -683,7 +720,9 @@ class SyncHandler:
 
         if last_event_id:
             state = await self.get_state_after_event(
-                last_event_id, state_filter=state_filter or StateFilter.all()
+                last_event_id,
+                state_filter=state_filter or StateFilter.all(),
+                await_full_state=await_full_state,
             )
 
         else:
@@ -857,16 +896,26 @@ class SyncHandler:
         now_token: StreamToken,
         full_state: bool,
     ) -> MutableStateMap[EventBase]:
-        """Works out the difference in state between the start of the timeline
-        and the previous sync.
+        """Works out the difference in state between the end of the previous sync and
+        the start of the timeline.
 
         Args:
             room_id:
             batch: The timeline batch for the room that will be sent to the user.
             sync_config:
-            since_token: Token of the end of the previous batch. May be None.
+            since_token: Token of the end of the previous batch. May be `None`.
             now_token: Token of the end of the current batch.
             full_state: Whether to force returning the full state.
+                `lazy_load_members` still applies when `full_state` is `True`.
+
+        Returns:
+            The state to return in the sync response for the room.
+
+            Clients will overlay this onto the state at the end of the previous sync to
+            arrive at the state at the start of the timeline.
+
+            Clients will then overlay state events in the timeline to arrive at the
+            state at the end of the timeline, in preparation for the next sync.
         """
         # TODO(mjark) Check if the state events were received by the server
         # after the previous sync, since we need to include those state
@@ -874,8 +923,17 @@ class SyncHandler:
         # TODO(mjark) Check for new redactions in the state events.
 
         with Measure(self.clock, "compute_state_delta"):
+            # The memberships needed for events in the timeline.
+            # Only calculated when `lazy_load_members` is on.
+            members_to_fetch: Optional[Set[str]] = None
+
+            # A dictionary mapping user IDs to the first event in the timeline sent by
+            # them. Only calculated when `lazy_load_members` is on.
+            first_event_by_sender_map: Optional[Dict[str, EventBase]] = None
 
-            members_to_fetch = None
+            # The contribution to the room state from state events in the timeline.
+            # Only contains the last event for any given state key.
+            timeline_state: StateMap[str]
 
             lazy_load_members = sync_config.filter_collection.lazy_load_members()
             include_redundant_members = (
@@ -886,10 +944,23 @@ class SyncHandler:
                 # We only request state for the members needed to display the
                 # timeline:
 
-                members_to_fetch = {
-                    event.sender  # FIXME: we also care about invite targets etc.
-                    for event in batch.events
-                }
+                timeline_state = {}
+
+                members_to_fetch = set()
+                first_event_by_sender_map = {}
+                for event in batch.events:
+                    # Build the map from user IDs to the first timeline event they sent.
+                    if event.sender not in first_event_by_sender_map:
+                        first_event_by_sender_map[event.sender] = event
+
+                    # We need the event's sender, unless their membership was in a
+                    # previous timeline event.
+                    if (EventTypes.Member, event.sender) not in timeline_state:
+                        members_to_fetch.add(event.sender)
+                    # FIXME: we also care about invite targets etc.
+
+                    if event.is_state():
+                        timeline_state[(event.type, event.state_key)] = event.event_id
 
                 if full_state:
                     # always make sure we LL ourselves so we know we're in the room
@@ -899,55 +970,80 @@ class SyncHandler:
                     members_to_fetch.add(sync_config.user.to_string())
 
                 state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch)
+
+                # We are happy to use partial state to compute the `/sync` response.
+                # Since partial state may not include the lazy-loaded memberships we
+                # require, we fix up the state response afterwards with memberships from
+                # auth events.
+                await_full_state = False
             else:
+                timeline_state = {
+                    (event.type, event.state_key): event.event_id
+                    for event in batch.events
+                    if event.is_state()
+                }
+
                 state_filter = StateFilter.all()
+                await_full_state = True
 
-            timeline_state = {
-                (event.type, event.state_key): event.event_id
-                for event in batch.events
-                if event.is_state()
-            }
+            # Now calculate the state to return in the sync response for the room.
+            # This is more or less the change in state between the end of the previous
+            # sync's timeline and the start of the current sync's timeline.
+            # See the docstring above for details.
+            state_ids: StateMap[str]
 
             if full_state:
                 if batch:
-                    current_state_ids = (
+                    state_at_timeline_end = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[-1].event_id, state_filter=state_filter
+                            batch.events[-1].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
 
-                    state_ids = (
+                    state_at_timeline_start = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[0].event_id, state_filter=state_filter
+                            batch.events[0].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
 
                 else:
-                    current_state_ids = await self.get_state_at(
-                        room_id, stream_position=now_token, state_filter=state_filter
+                    state_at_timeline_end = await self.get_state_at(
+                        room_id,
+                        stream_position=now_token,
+                        state_filter=state_filter,
+                        await_full_state=await_full_state,
                     )
 
-                    state_ids = current_state_ids
+                    state_at_timeline_start = state_at_timeline_end
 
                 state_ids = _calculate_state(
                     timeline_contains=timeline_state,
-                    timeline_start=state_ids,
-                    previous={},
-                    current=current_state_ids,
+                    timeline_start=state_at_timeline_start,
+                    timeline_end=state_at_timeline_end,
+                    previous_timeline_end={},
                     lazy_load_members=lazy_load_members,
                 )
             elif batch.limited:
                 if batch:
                     state_at_timeline_start = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[0].event_id, state_filter=state_filter
+                            batch.events[0].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
                 else:
                     # We can get here if the user has ignored the senders of all
                     # the recent events.
                     state_at_timeline_start = await self.get_state_at(
-                        room_id, stream_position=now_token, state_filter=state_filter
+                        room_id,
+                        stream_position=now_token,
+                        state_filter=state_filter,
+                        await_full_state=await_full_state,
                     )
 
                 # for now, we disable LL for gappy syncs - see
@@ -969,28 +1065,35 @@ class SyncHandler:
                 # is indeed the case.
                 assert since_token is not None
                 state_at_previous_sync = await self.get_state_at(
-                    room_id, stream_position=since_token, state_filter=state_filter
+                    room_id,
+                    stream_position=since_token,
+                    state_filter=state_filter,
+                    await_full_state=await_full_state,
                 )
 
                 if batch:
-                    current_state_ids = (
+                    state_at_timeline_end = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[-1].event_id, state_filter=state_filter
+                            batch.events[-1].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
                 else:
-                    # Its not clear how we get here, but empirically we do
-                    # (#5407). Logging has been added elsewhere to try and
-                    # figure out where this state comes from.
-                    current_state_ids = await self.get_state_at(
-                        room_id, stream_position=now_token, state_filter=state_filter
+                    # We can get here if the user has ignored the senders of all
+                    # the recent events.
+                    state_at_timeline_end = await self.get_state_at(
+                        room_id,
+                        stream_position=now_token,
+                        state_filter=state_filter,
+                        await_full_state=await_full_state,
                     )
 
                 state_ids = _calculate_state(
                     timeline_contains=timeline_state,
                     timeline_start=state_at_timeline_start,
-                    previous=state_at_previous_sync,
-                    current=current_state_ids,
+                    timeline_end=state_at_timeline_end,
+                    previous_timeline_end=state_at_previous_sync,
                     # we have to include LL members in case LL initial sync missed them
                     lazy_load_members=lazy_load_members,
                 )
@@ -1013,8 +1116,30 @@ class SyncHandler:
                                 (EventTypes.Member, member)
                                 for member in members_to_fetch
                             ),
+                            await_full_state=False,
                         )
 
+            # If we only have partial state for the room, `state_ids` may be missing the
+            # memberships we wanted. We attempt to find some by digging through the auth
+            # events of timeline events.
+            if lazy_load_members and await self.store.is_partial_state_room(room_id):
+                assert members_to_fetch is not None
+                assert first_event_by_sender_map is not None
+
+                additional_state_ids = (
+                    await self._find_missing_partial_state_memberships(
+                        room_id, members_to_fetch, first_event_by_sender_map, state_ids
+                    )
+                )
+                state_ids = {**state_ids, **additional_state_ids}
+
+            # At this point, if `lazy_load_members` is enabled, `state_ids` includes
+            # the memberships of all event senders in the timeline. This is because we
+            # may not have sent the memberships in a previous sync.
+
+            # When `include_redundant_members` is on, we send all the lazy-loaded
+            # memberships of event senders. Otherwise we make an effort to limit the set
+            # of memberships we send to those that we have not already sent to this client.
             if lazy_load_members and not include_redundant_members:
                 cache_key = (sync_config.user.to_string(), sync_config.device_id)
                 cache = self.get_lazy_loaded_members_cache(cache_key)
@@ -1056,6 +1181,99 @@ class SyncHandler:
             if e.type != EventTypes.Aliases  # until MSC2261 or alternative solution
         }
 
+    async def _find_missing_partial_state_memberships(
+        self,
+        room_id: str,
+        members_to_fetch: Collection[str],
+        events_with_membership_auth: Mapping[str, EventBase],
+        found_state_ids: StateMap[str],
+    ) -> StateMap[str]:
+        """Finds missing memberships from a set of auth events and returns them as a
+        state map.
+
+        Args:
+            room_id: The partial state room to find the remaining memberships for.
+            members_to_fetch: The memberships to find.
+            events_with_membership_auth: A mapping from user IDs to events whose auth
+                events are known to contain their membership.
+            found_state_ids: A dict from (type, state_key) -> state_event_id, containing
+                memberships that have been previously found. Entries in
+                `members_to_fetch` that have a membership in `found_state_ids` are
+                ignored.
+
+        Returns:
+            A dict from ("m.room.member", state_key) -> state_event_id, containing the
+            memberships missing from `found_state_ids`.
+
+        Raises:
+            KeyError: if `events_with_membership_auth` does not have an entry for a
+                missing membership. Memberships in `found_state_ids` do not need an
+                entry in `events_with_membership_auth`.
+        """
+        additional_state_ids: MutableStateMap[str] = {}
+
+        # Tracks the missing members for logging purposes.
+        missing_members = set()
+
+        # Identify memberships missing from `found_state_ids` and pick out the auth
+        # events in which to look for them.
+        auth_event_ids: Set[str] = set()
+        for member in members_to_fetch:
+            if (EventTypes.Member, member) in found_state_ids:
+                continue
+
+            missing_members.add(member)
+            event_with_membership_auth = events_with_membership_auth[member]
+            auth_event_ids.update(event_with_membership_auth.auth_event_ids())
+
+        auth_events = await self.store.get_events(auth_event_ids)
+
+        # Run through the missing memberships once more, picking out the memberships
+        # from the pile of auth events we have just fetched.
+        for member in members_to_fetch:
+            if (EventTypes.Member, member) in found_state_ids:
+                continue
+
+            event_with_membership_auth = events_with_membership_auth[member]
+
+            # Dig through the auth events to find the desired membership.
+            for auth_event_id in event_with_membership_auth.auth_event_ids():
+                # We only store events once we have all their auth events,
+                # so the auth event must be in the pile we have just
+                # fetched.
+                auth_event = auth_events[auth_event_id]
+
+                if (
+                    auth_event.type == EventTypes.Member
+                    and auth_event.state_key == member
+                ):
+                    missing_members.remove(member)
+                    additional_state_ids[
+                        (EventTypes.Member, member)
+                    ] = auth_event.event_id
+                    break
+
+        if missing_members:
+            # There really shouldn't be any missing memberships now. Either:
+            #  * we couldn't find an auth event, which shouldn't happen because we do
+            #    not persist events with persisting their auth events first, or
+            #  * the set of auth events did not contain a membership we wanted, which
+            #    means our caller didn't compute the events in `members_to_fetch`
+            #    correctly, or we somehow accepted an event whose auth events were
+            #    dodgy.
+            logger.error(
+                "Failed to find memberships for %s in partial state room "
+                "%s in the auth events of %s.",
+                missing_members,
+                room_id,
+                [
+                    events_with_membership_auth[member].event_id
+                    for member in missing_members
+                ],
+            )
+
+        return additional_state_ids
+
     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
     ) -> NotifCounts:
@@ -1700,7 +1918,11 @@ class SyncHandler:
                 continue
 
             if room_id in sync_result_builder.joined_room_ids or has_join:
-                old_state_ids = await self.get_state_at(room_id, since_token)
+                old_state_ids = await self.get_state_at(
+                    room_id,
+                    since_token,
+                    state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]),
+                )
                 old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
                 old_mem_ev = None
                 if old_mem_ev_id:
@@ -1726,7 +1948,13 @@ class SyncHandler:
                     newly_left_rooms.append(room_id)
                 else:
                     if not old_state_ids:
-                        old_state_ids = await self.get_state_at(room_id, since_token)
+                        old_state_ids = await self.get_state_at(
+                            room_id,
+                            since_token,
+                            state_filter=StateFilter.from_types(
+                                [(EventTypes.Member, user_id)]
+                            ),
+                        )
                         old_mem_ev_id = old_state_ids.get(
                             (EventTypes.Member, user_id), None
                         )
@@ -2221,8 +2449,8 @@ def _action_has_highlight(actions: List[JsonDict]) -> bool:
 def _calculate_state(
     timeline_contains: StateMap[str],
     timeline_start: StateMap[str],
-    previous: StateMap[str],
-    current: StateMap[str],
+    timeline_end: StateMap[str],
+    previous_timeline_end: StateMap[str],
     lazy_load_members: bool,
 ) -> StateMap[str]:
     """Works out what state to include in a sync response.
@@ -2230,45 +2458,50 @@ def _calculate_state(
     Args:
         timeline_contains: state in the timeline
         timeline_start: state at the start of the timeline
-        previous: state at the end of the previous sync (or empty dict
+        timeline_end: state at the end of the timeline
+        previous_timeline_end: state at the end of the previous sync (or empty dict
             if this is an initial sync)
-        current: state at the end of the timeline
         lazy_load_members: whether to return members from timeline_start
             or not.  assumes that timeline_start has already been filtered to
             include only the members the client needs to know about.
     """
-    event_id_to_key = {
-        e: key
-        for key, e in itertools.chain(
+    event_id_to_state_key = {
+        event_id: state_key
+        for state_key, event_id in itertools.chain(
             timeline_contains.items(),
-            previous.items(),
             timeline_start.items(),
-            current.items(),
+            timeline_end.items(),
+            previous_timeline_end.items(),
         )
     }
 
-    c_ids = set(current.values())
-    ts_ids = set(timeline_start.values())
-    p_ids = set(previous.values())
-    tc_ids = set(timeline_contains.values())
+    timeline_end_ids = set(timeline_end.values())
+    timeline_start_ids = set(timeline_start.values())
+    previous_timeline_end_ids = set(previous_timeline_end.values())
+    timeline_contains_ids = set(timeline_contains.values())
 
     # If we are lazyloading room members, we explicitly add the membership events
     # for the senders in the timeline into the state block returned by /sync,
     # as we may not have sent them to the client before.  We find these membership
     # events by filtering them out of timeline_start, which has already been filtered
     # to only include membership events for the senders in the timeline.
-    # In practice, we can do this by removing them from the p_ids list,
-    # which is the list of relevant state we know we have already sent to the client.
+    # In practice, we can do this by removing them from the previous_timeline_end_ids
+    # list, which is the list of relevant state we know we have already sent to the
+    # client.
     # see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
 
     if lazy_load_members:
-        p_ids.difference_update(
+        previous_timeline_end_ids.difference_update(
             e for t, e in timeline_start.items() if t[0] == EventTypes.Member
         )
 
-    state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
+    state_ids = (
+        (timeline_end_ids | timeline_start_ids)
+        - previous_timeline_end_ids
+        - timeline_contains_ids
+    )
 
-    return {event_id_to_key[e]: e for e in state_ids}
+    return {event_id_to_state_key[e]: e for e in state_ids}
 
 
 @attr.s(slots=True, auto_attribs=True)