diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index b06387051c..65778fd4ee 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -314,6 +314,40 @@ class FederationClient(FederationBase):
Deferred: Results in a list of PDUs.
"""
+ try:
+ # First we try and ask for just the IDs, as thats far quicker if
+ # we have most of the state and auth_chain already.
+ # However, this may 404 if the other side has an old synapse.
+ result = yield self.transport_layer.get_room_state_ids(
+ destination, room_id, event_id=event_id,
+ )
+
+ state_event_ids = result["pdu_ids"]
+ auth_event_ids = result.get("auth_chain_ids", [])
+
+ fetched_events, failed_to_fetch = yield self.get_events(
+ [destination], room_id, set(state_event_ids + auth_event_ids)
+ )
+
+ if failed_to_fetch:
+ logger.warn("Failed to get %r", failed_to_fetch)
+
+ event_map = {
+ ev.event_id: ev for ev in fetched_events
+ }
+
+ pdus = [event_map[e_id] for e_id in state_event_ids]
+ auth_chain = [event_map[e_id] for e_id in auth_event_ids]
+
+ auth_chain.sort(key=lambda e: e.depth)
+
+ defer.returnValue((pdus, auth_chain))
+ except HttpResponseException as e:
+ if e.code == 400 or e.code == 404:
+ logger.info("Failed to use get_room_state_ids API, falling back")
+ else:
+ raise e
+
result = yield self.transport_layer.get_room_state(
destination, room_id, event_id=event_id,
)
@@ -327,12 +361,26 @@ class FederationClient(FederationBase):
for p in result.get("auth_chain", [])
]
+ seen_events = yield self.store.get_events([
+ ev.event_id for ev in itertools.chain(pdus, auth_chain)
+ ])
+
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
- destination, pdus, outlier=True
+ destination,
+ [p for p in pdus if p.event_id not in seen_events],
+ outlier=True
+ )
+ signed_pdus.extend(
+ seen_events[p.event_id] for p in pdus if p.event_id in seen_events
)
signed_auth = yield self._check_sigs_and_hash_and_fetch(
- destination, auth_chain, outlier=True
+ destination,
+ [p for p in auth_chain if p.event_id not in seen_events],
+ outlier=True
+ )
+ signed_auth.extend(
+ seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
)
signed_auth.sort(key=lambda e: e.depth)
@@ -340,6 +388,67 @@ class FederationClient(FederationBase):
defer.returnValue((signed_pdus, signed_auth))
@defer.inlineCallbacks
+ def get_events(self, destinations, room_id, event_ids, return_local=True):
+ """Fetch events from some remote destinations, checking if we already
+ have them.
+
+ Args:
+ destinations (list)
+ room_id (str)
+ event_ids (list)
+ return_local (bool): Whether to include events we already have in
+ the DB in the returned list of events
+
+ Returns:
+ Deferred: A deferred resolving to a 2-tuple where the first is a list of
+ events and the second is a list of event ids that we failed to fetch.
+ """
+ if return_local:
+ seen_events = yield self.store.get_events(event_ids)
+ signed_events = seen_events.values()
+ else:
+ seen_events = yield self.store.have_events(event_ids)
+ signed_events = []
+
+ failed_to_fetch = set()
+
+ missing_events = set(event_ids)
+ for k in seen_events:
+ missing_events.discard(k)
+
+ if not missing_events:
+ defer.returnValue((signed_events, failed_to_fetch))
+
+ def random_server_list():
+ srvs = list(destinations)
+ random.shuffle(srvs)
+ return srvs
+
+ batch_size = 20
+ missing_events = list(missing_events)
+ for i in xrange(0, len(missing_events), batch_size):
+ batch = set(missing_events[i:i + batch_size])
+
+ deferreds = [
+ self.get_pdu(
+ destinations=random_server_list(),
+ event_id=e_id,
+ )
+ for e_id in batch
+ ]
+
+ res = yield defer.DeferredList(deferreds, consumeErrors=True)
+ for success, result in res:
+ if success:
+ signed_events.append(result)
+ batch.discard(result.event_id)
+
+ # We removed all events we successfully fetched from `batch`
+ failed_to_fetch.update(batch)
+
+ defer.returnValue((signed_events, failed_to_fetch))
+
+ @defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
res = yield self.transport_layer.get_event_auth(
|