summary refs log tree commit diff
path: root/synapse/federation/federation_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_client.py')
-rw-r--r--synapse/federation/federation_client.py44
1 files changed, 15 insertions, 29 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 709449c9e3..73e1dda6a3 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,8 +18,6 @@ import copy
 import itertools
 import logging
 
-from six.moves import range
-
 from prometheus_client import Counter
 
 from twisted.internet import defer
@@ -41,7 +39,7 @@ from synapse.events import builder, room_version_to_event_format
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.utils import log_function
-from synapse.util import unwrapFirstError
+from synapse.util import batch_iter, unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
@@ -331,10 +329,12 @@ class FederationClient(FederationBase):
         state_event_ids = result["pdu_ids"]
         auth_event_ids = result.get("auth_chain_ids", [])
 
-        fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest(
-            destination, room_id, set(state_event_ids + auth_event_ids)
+        desired_events = set(state_event_ids + auth_event_ids)
+        event_map = yield self.get_events_from_store_or_dest(
+            destination, room_id, desired_events
         )
 
+        failed_to_fetch = desired_events - event_map.keys()
         if failed_to_fetch:
             logger.warning(
                 "Failed to fetch missing state/auth events for %s: %s",
@@ -342,8 +342,6 @@ class FederationClient(FederationBase):
                 failed_to_fetch,
             )
 
-        event_map = {ev.event_id: ev for ev in fetched_events}
-
         pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
         auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
 
@@ -358,23 +356,18 @@ class FederationClient(FederationBase):
         Args:
             destination (str)
             room_id (str)
-            event_ids (list)
+            event_ids (Iterable[str])
 
         Returns:
-            Deferred: A deferred resolving to a 2-tuple where the first is a list of
-            events and the second is a list of event ids that we failed to fetch.
+            Deferred[dict[str, EventBase]]: A deferred resolving to a map
+            from event_id to event
         """
-        seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
-        signed_events = list(seen_events.values())
-
-        failed_to_fetch = set()
+        fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
 
-        missing_events = set(event_ids)
-        for k in seen_events:
-            missing_events.discard(k)
+        missing_events = set(event_ids) - fetched_events.keys()
 
         if not missing_events:
-            return signed_events, failed_to_fetch
+            return fetched_events
 
         logger.debug(
             "Fetching unknown state/auth events %s for room %s",
@@ -384,11 +377,8 @@ class FederationClient(FederationBase):
 
         room_version = yield self.store.get_room_version(room_id)
 
-        batch_size = 20
-        missing_events = list(missing_events)
-        for i in range(0, len(missing_events), batch_size):
-            batch = set(missing_events[i : i + batch_size])
-
+        # XXX 20 requests at once? really?
+        for batch in batch_iter(missing_events, 20):
             deferreds = [
                 run_in_background(
                     self.get_pdu,
@@ -404,13 +394,9 @@ class FederationClient(FederationBase):
             )
             for success, result in res:
                 if success and result:
-                    signed_events.append(result)
-                    batch.discard(result.event_id)
-
-            # We removed all events we successfully fetched from `batch`
-            failed_to_fetch.update(batch)
+                    fetched_events[result.event_id] = result
 
-        return signed_events, failed_to_fetch
+        return fetched_events
 
     @defer.inlineCallbacks
     @log_function