summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/federation.py273
1 files changed, 40 insertions, 233 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 246df43501..6fa2fc8f52 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -65,6 +65,7 @@ from synapse.event_auth import auth_types_for_event
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
+from synapse.federation.federation_client import InvalidResponseError
 from synapse.handlers._base import BaseHandler
 from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import (
@@ -116,10 +117,6 @@ class _NewEventInfo:
     Attributes:
         event: the received event
 
-        state: the state at that event, according to /state_ids from a remote
-           homeserver. Only populated for backfilled events which are going to be a
-           new backwards extremity.
-
         claimed_auth_event_map: a map of (type, state_key) => event for the event's
             claimed auth_events.
 
@@ -134,7 +131,6 @@ class _NewEventInfo:
     """
 
     event: EventBase
-    state: Optional[Sequence[EventBase]]
     claimed_auth_event_map: StateMap[EventBase]
 
 
@@ -443,113 +439,7 @@ class FederationHandler(BaseHandler):
             return
 
         logger.info("Got %d prev_events", len(missing_events))
-        await self._process_pulled_events(origin, missing_events)
-
-    async def _get_state_for_room(
-        self,
-        destination: str,
-        room_id: str,
-        event_id: str,
-    ) -> List[EventBase]:
-        """Requests all of the room state at a given event from a remote
-        homeserver.
-
-        Will also fetch any missing events reported in the `auth_chain_ids`
-        section of `/state_ids`.
-
-        Args:
-            destination: The remote homeserver to query for the state.
-            room_id: The id of the room we're interested in.
-            event_id: The id of the event we want the state at.
-
-        Returns:
-            A list of events in the state, not including the event itself.
-        """
-        (
-            state_event_ids,
-            auth_event_ids,
-        ) = await self.federation_client.get_room_state_ids(
-            destination, room_id, event_id=event_id
-        )
-
-        # Fetch the state events from the DB, and check we have the auth events.
-        event_map = await self.store.get_events(state_event_ids, allow_rejected=True)
-        auth_events_in_store = await self.store.have_seen_events(
-            room_id, auth_event_ids
-        )
-
-        # Check for missing events. We handle state and auth event seperately,
-        # as we want to pull the state from the DB, but we don't for the auth
-        # events. (Note: we likely won't use the majority of the auth chain, and
-        # it can be *huge* for large rooms, so it's worth ensuring that we don't
-        # unnecessarily pull it from the DB).
-        missing_state_events = set(state_event_ids) - set(event_map)
-        missing_auth_events = set(auth_event_ids) - set(auth_events_in_store)
-        if missing_state_events or missing_auth_events:
-            await self._get_events_and_persist(
-                destination=destination,
-                room_id=room_id,
-                events=missing_state_events | missing_auth_events,
-            )
-
-            if missing_state_events:
-                new_events = await self.store.get_events(
-                    missing_state_events, allow_rejected=True
-                )
-                event_map.update(new_events)
-
-                missing_state_events.difference_update(new_events)
-
-                if missing_state_events:
-                    logger.warning(
-                        "Failed to fetch missing state events for %s %s",
-                        event_id,
-                        missing_state_events,
-                    )
-
-            if missing_auth_events:
-                auth_events_in_store = await self.store.have_seen_events(
-                    room_id, missing_auth_events
-                )
-                missing_auth_events.difference_update(auth_events_in_store)
-
-                if missing_auth_events:
-                    logger.warning(
-                        "Failed to fetch missing auth events for %s %s",
-                        event_id,
-                        missing_auth_events,
-                    )
-
-        remote_state = list(event_map.values())
-
-        # check for events which were in the wrong room.
-        #
-        # this can happen if a remote server claims that the state or
-        # auth_events at an event in room A are actually events in room B
-
-        bad_events = [
-            (event.event_id, event.room_id)
-            for event in remote_state
-            if event.room_id != room_id
-        ]
-
-        for bad_event_id, bad_room_id in bad_events:
-            # This is a bogus situation, but since we may only discover it a long time
-            # after it happened, we try our best to carry on, by just omitting the
-            # bad events from the returned auth/state set.
-            logger.warning(
-                "Remote server %s claims event %s in room %s is an auth/state "
-                "event in room %s",
-                destination,
-                bad_event_id,
-                bad_room_id,
-                room_id,
-            )
-
-        if bad_events:
-            remote_state = [e for e in remote_state if e.room_id == room_id]
-
-        return remote_state
+        await self._process_pulled_events(origin, missing_events, backfilled=False)
 
     async def _get_state_after_missing_prev_event(
         self,
@@ -567,10 +457,6 @@ class FederationHandler(BaseHandler):
         Returns:
             A list of events in the state, including the event itself
         """
-        # TODO: This function is basically the same as _get_state_for_room. Can
-        #   we make backfill() use it, rather than having two code paths? I think the
-        #   only difference is that backfill() persists the prev events separately.
-
         (
             state_event_ids,
             auth_event_ids,
@@ -681,6 +567,7 @@ class FederationHandler(BaseHandler):
         origin: str,
         event: EventBase,
         state: Optional[Iterable[EventBase]],
+        backfilled: bool = False,
     ) -> None:
         """Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
@@ -693,6 +580,9 @@ class FederationHandler(BaseHandler):
             state: Normally None, but if we are handling a gap in the graph
                 (ie, we are missing one or more prev_events), the resolved state at the
                 event
+
+            backfilled: True if this is part of a historical batch of events (inhibits
+                notification to clients, and validation of device keys.)
         """
         logger.debug("Processing event: %s", event)
 
@@ -700,10 +590,15 @@ class FederationHandler(BaseHandler):
             context = await self.state_handler.compute_event_context(
                 event, old_state=state
             )
-            await self._auth_and_persist_event(origin, event, context, state=state)
+            await self._auth_and_persist_event(
+                origin, event, context, state=state, backfilled=backfilled
+            )
         except AuthError as e:
             raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
 
+        if backfilled:
+            return
+
         # For encrypted messages we check that we know about the sending device,
         # if we don't then we mark the device cache for that user as stale.
         if event.type == EventTypes.Encrypted:
@@ -868,7 +763,7 @@ class FederationHandler(BaseHandler):
     @log_function
     async def backfill(
         self, dest: str, room_id: str, limit: int, extremities: List[str]
-    ) -> List[EventBase]:
+    ) -> None:
         """Trigger a backfill request to `dest` for the given `room_id`
 
         This will attempt to get more events from the remote. If the other side
@@ -878,6 +773,9 @@ class FederationHandler(BaseHandler):
         sanity-checking on them. If any of the backfilled events are invalid,
         this method throws a SynapseError.
 
+        We might also raise an InvalidResponseError if the response from the remote
+        server is just bogus.
+
         TODO: make this more useful to distinguish failures of the remote
         server from invalid events (there is probably no point in trying to
         re-fetch invalid events from every other HS in the room.)
@@ -890,111 +788,18 @@ class FederationHandler(BaseHandler):
         )
 
         if not events:
-            return []
-
-        # ideally we'd sanity check the events here for excess prev_events etc,
-        # but it's hard to reject events at this point without completely
-        # breaking backfill in the same way that it is currently broken by
-        # events whose signature we cannot verify (#3121).
-        #
-        # So for now we accept the events anyway. #3124 tracks this.
-        #
-        # for ev in events:
-        #     self._sanity_check_event(ev)
-
-        # Don't bother processing events we already have.
-        seen_events = await self.store.have_events_in_timeline(
-            {e.event_id for e in events}
-        )
-
-        events = [e for e in events if e.event_id not in seen_events]
-
-        if not events:
-            return []
-
-        event_map = {e.event_id: e for e in events}
-
-        event_ids = {e.event_id for e in events}
-
-        # build a list of events whose prev_events weren't in the batch.
-        # (XXX: this will include events whose prev_events we already have; that doesn't
-        # sound right?)
-        edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]
-
-        logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
-
-        # For each edge get the current state.
-
-        state_events = {}
-        events_to_state = {}
-        for e_id in edges:
-            state = await self._get_state_for_room(
-                destination=dest,
-                room_id=room_id,
-                event_id=e_id,
-            )
-            state_events.update({s.event_id: s for s in state})
-            events_to_state[e_id] = state
+            return
 
-        required_auth = {
-            a_id
-            for event in events + list(state_events.values())
-            for a_id in event.auth_event_ids()
-        }
-        auth_events = await self.store.get_events(required_auth, allow_rejected=True)
-        auth_events.update(
-            {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
-        )
-
-        ev_infos = []
-
-        # Step 1: persist the events in the chunk we fetched state for (i.e.
-        # the backwards extremities), with custom auth events and state
-        for e_id in events_to_state:
-            # For paranoia we ensure that these events are marked as
-            # non-outliers
-            ev = event_map[e_id]
-            assert not ev.internal_metadata.is_outlier()
-
-            ev_infos.append(
-                _NewEventInfo(
-                    event=ev,
-                    state=events_to_state[e_id],
-                    claimed_auth_event_map={
-                        (
-                            auth_events[a_id].type,
-                            auth_events[a_id].state_key,
-                        ): auth_events[a_id]
-                        for a_id in ev.auth_event_ids()
-                        if a_id in auth_events
-                    },
+        # if there are any events in the wrong room, the remote server is buggy and
+        # should not be trusted.
+        for ev in events:
+            if ev.room_id != room_id:
+                raise InvalidResponseError(
+                    f"Remote server {dest} returned event {ev.event_id} which is in "
+                    f"room {ev.room_id}, when we were backfilling in {room_id}"
                 )
-            )
-
-        if ev_infos:
-            await self._auth_and_persist_events(
-                dest, room_id, ev_infos, backfilled=True
-            )
-
-        # Step 2: Persist the rest of the events in the chunk one by one
-        events.sort(key=lambda e: e.depth)
-
-        for event in events:
-            if event in events_to_state:
-                continue
-
-            # For paranoia we ensure that these events are marked as
-            # non-outliers
-            assert not event.internal_metadata.is_outlier()
-
-            context = await self.state_handler.compute_event_context(event)
-
-            # We store these one at a time since each event depends on the
-            # previous to work out the state.
-            # TODO: We can probably do something more clever here.
-            await self._auth_and_persist_event(dest, event, context, backfilled=True)
 
-        return events
+        await self._process_pulled_events(dest, events, backfilled=True)
 
     async def maybe_backfill(
         self, room_id: str, current_depth: int, limit: int
@@ -1197,7 +1002,7 @@ class FederationHandler(BaseHandler):
                     # appropriate stuff.
                     # TODO: We can probably do something more intelligent here.
                     return True
-                except SynapseError as e:
+                except (SynapseError, InvalidResponseError) as e:
                     logger.info("Failed to backfill from %s because %s", dom, e)
                     continue
                 except HttpResponseException as e:
@@ -1351,7 +1156,7 @@ class FederationHandler(BaseHandler):
                 else:
                     logger.info("Missing auth event %s", auth_event_id)
 
-            event_infos.append(_NewEventInfo(event, None, auth))
+            event_infos.append(_NewEventInfo(event, auth))
 
         if event_infos:
             await self._auth_and_persist_events(
@@ -1361,7 +1166,7 @@ class FederationHandler(BaseHandler):
             )
 
     async def _process_pulled_events(
-        self, origin: str, events: Iterable[EventBase]
+        self, origin: str, events: Iterable[EventBase], backfilled: bool
     ) -> None:
         """Process a batch of events we have pulled from a remote server
 
@@ -1373,6 +1178,8 @@ class FederationHandler(BaseHandler):
         Params:
             origin: The server we received these events from
             events: The received events.
+            backfilled: True if this is part of a historical batch of events (inhibits
+                notification to clients, and validation of device keys.)
         """
 
         # We want to sort these by depth so we process them and
@@ -1381,9 +1188,11 @@ class FederationHandler(BaseHandler):
 
         for ev in sorted_events:
             with nested_logging_context(ev.event_id):
-                await self._process_pulled_event(origin, ev)
+                await self._process_pulled_event(origin, ev, backfilled=backfilled)
 
-    async def _process_pulled_event(self, origin: str, event: EventBase) -> None:
+    async def _process_pulled_event(
+        self, origin: str, event: EventBase, backfilled: bool
+    ) -> None:
         """Process a single event that we have pulled from a remote server
 
         Pulls in any events required to auth the event, persists the received event,
@@ -1400,6 +1209,8 @@ class FederationHandler(BaseHandler):
         Params:
             origin: The server we received this event from
             events: The received event
+            backfilled: True if this is part of a historical batch of events (inhibits
+                notification to clients, and validation of device keys.)
         """
         logger.info("Processing pulled event %s", event)
 
@@ -1428,7 +1239,9 @@ class FederationHandler(BaseHandler):
 
         try:
             state = await self._resolve_state_at_missing_prevs(origin, event)
-            await self._process_received_pdu(origin, event, state=state)
+            await self._process_received_pdu(
+                origin, event, state=state, backfilled=backfilled
+            )
         except FederationError as e:
             if e.code == 403:
                 logger.warning("Pulled event %s failed history check.", event_id)
@@ -2451,7 +2264,6 @@ class FederationHandler(BaseHandler):
         origin: str,
         room_id: str,
         event_infos: Collection[_NewEventInfo],
-        backfilled: bool = False,
     ) -> None:
         """Creates the appropriate contexts and persists events. The events
         should not depend on one another, e.g. this should be used to persist
@@ -2467,16 +2279,12 @@ class FederationHandler(BaseHandler):
         async def prep(ev_info: _NewEventInfo):
             event = ev_info.event
             with nested_logging_context(suffix=event.event_id):
-                res = await self.state_handler.compute_event_context(
-                    event, old_state=ev_info.state
-                )
+                res = await self.state_handler.compute_event_context(event)
                 res = await self._check_event_auth(
                     origin,
                     event,
                     res,
-                    state=ev_info.state,
                     claimed_auth_event_map=ev_info.claimed_auth_event_map,
-                    backfilled=backfilled,
                 )
             return res
 
@@ -2493,7 +2301,6 @@ class FederationHandler(BaseHandler):
                 (ev_info.event, context)
                 for ev_info, context in zip(event_infos, contexts)
             ],
-            backfilled=backfilled,
         )
 
     async def _persist_auth_tree(