diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 567afc910f..e7b9f15e13 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -897,10 +897,24 @@ class FederationEventHandler:
         logger.debug("We are also missing %i auth events", len(missing_auth_events))
         missing_events = missing_desired_events | missing_auth_events
-        logger.debug("Fetching %i events from remote", len(missing_events))
-        await self._get_events_and_persist(
-            destination=destination, room_id=room_id, event_ids=missing_events
-        )
+
+        # Making an individual request for each of 1000s of events has a lot of
+        # overhead. On the other hand, we don't really want to fetch all of the events
+        # if we already have most of them.
+        #
+        # As an arbitrary heuristic, if we are missing more than 10% of the events, then
+        # we fetch the whole state.
+        #
+        # TODO: might it be better to have an API which lets us do an aggregate event
+        #   request
+        if (len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids):
+            logger.debug("Requesting complete state from remote")
+            await self._get_state_and_persist(destination, room_id, event_id)
+        else:
+            logger.debug("Fetching %i events from remote", len(missing_events))
+            await self._get_events_and_persist(
+                destination=destination, room_id=room_id, event_ids=missing_events
+            )
         # we need to make sure we re-load from the database to get the rejected
         # state correct.
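
For illustration only (not part of the patch): the 10% heuristic in the hunk above, pulled out as a standalone function with a couple of worked figures. The function name and the event counts are hypothetical.

# Sketch of the fetch-strategy heuristic above; in the patch the check is
# inlined in the handler rather than factored out like this.
from typing import Collection


def should_fetch_full_state(
    missing_events: Collection[str],
    auth_event_ids: Collection[str],
    state_event_ids: Collection[str],
) -> bool:
    """True when at least 10% of the room's auth/state events are missing,
    i.e. when one /state request is cheaper than many per-event requests."""
    return len(missing_events) * 10 >= len(auth_event_ids) + len(state_event_ids)


# 150 of 1000 events missing: 1500 >= 1000, so fetch the complete state.
assert should_fetch_full_state(["$m"] * 150, ["$a"] * 200, ["$s"] * 800)
# 50 of 1000 events missing: 500 < 1000, so fetch those 50 events individually.
assert not should_fetch_full_state(["$m"] * 50, ["$a"] * 200, ["$s"] * 800)

With 1,000 auth/state events in total, the crossover point is 100 missing events: at or above that, a single /state request replaces at least 100 individual event requests.
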
@@ -959,6 +973,27 @@ class FederationEventHandler:
         return remote_state
+    async def _get_state_and_persist(
+        self, destination: str, room_id: str, event_id: str
+    ) -> None:
+        """Get the complete room state at a given event, and persist any new events
+        as outliers"""
+        room_version = await self._store.get_room_version(room_id)
+        auth_events, state_events = await self._federation_client.get_room_state(
+            destination, room_id, event_id=event_id, room_version=room_version
+        )
+        logger.info("/state returned %i events", len(auth_events) + len(state_events))
+
+        await self._auth_and_persist_outliers(
+            room_id, itertools.chain(auth_events, state_events)
+        )
+
+        # we also need the event itself.
+        if not await self._store.have_seen_event(room_id, event_id):
+            await self._get_events_and_persist(
+                destination=destination, room_id=room_id, event_ids=(event_id,)
+            )
+
     async def _process_received_pdu(
         self,
         origin: str,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 59454a47df..a60e3f4fdd 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -22,7 +22,6 @@ from typing import (
     Dict,
     Iterable,
     List,
-    NoReturn,
     Optional,
     Set,
     Tuple,
@@ -1330,10 +1329,9 @@ class EventsWorkerStore(SQLBaseStore):
         return results
     @cached(max_entries=100000, tree=True)
-    async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
-        # this only exists for the benefit of the @cachedList descriptor on
-        # _have_seen_events_dict
-        raise NotImplementedError()
+    async def have_seen_event(self, room_id: str, event_id: str) -> bool:
+        res = await self._have_seen_events_dict(((room_id, event_id),))
+        return res[(room_id, event_id)]
     def _get_current_state_event_counts_txn(
         self, txn: LoggingTransaction, room_id: str
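
For context on the events_worker.py change (again, not part of the patch): the old have_seen_event existed only to give the @cachedList descriptor on _have_seen_events_dict a per-key cache to attach to, and raised NotImplementedError, so the new _get_state_and_persist above could not use it as a real lookup. The patched version answers a single lookup by delegating to the batch method and returns an actual bool. Below is a minimal sketch of that delegate-to-batch shape in plain Python; the in-memory store and its contents are made up, and Synapse's @cached/@cachedList machinery is deliberately left out.

# Sketch only: a single-key lookup delegating to a batch lookup, mirroring
# the patched have_seen_event. FakeEventStore stands in for the real
# storage layer.
import asyncio
from typing import Dict, Iterable, Tuple


class FakeEventStore:
    def __init__(self, seen: Iterable[Tuple[str, str]]) -> None:
        self._seen = set(seen)

    async def _have_seen_events_dict(
        self, keys: Iterable[Tuple[str, str]]
    ) -> Dict[Tuple[str, str], bool]:
        # Batch entry point: one call answers many (room_id, event_id) keys.
        return {key: key in self._seen for key in keys}

    async def have_seen_event(self, room_id: str, event_id: str) -> bool:
        # Single-key entry point: reuse the batch method instead of
        # duplicating the lookup logic.
        res = await self._have_seen_events_dict(((room_id, event_id),))
        return res[(room_id, event_id)]


async def main() -> None:
    store = FakeEventStore([("!room:example.org", "$seen_event")])
    assert await store.have_seen_event("!room:example.org", "$seen_event")
    assert not await store.have_seen_event("!room:example.org", "$other_event")


asyncio.run(main())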