diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index b06387051c..03f6133e61 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -314,9 +314,32 @@ class FederationClient(FederationBase):
Deferred: Results in a list of PDUs.
"""
- result = yield self.transport_layer.get_room_state(
- destination, room_id, event_id=event_id,
- )
+ try:
+ # First we try and ask for just the IDs, as thats far quicker if
+ # we have most of the state and auth_chain already.
+ # However, this may 404 if the other side has an old synapse.
+ result = yield self.transport_layer.get_room_state_ids(
+ destination, room_id, event_id=event_id,
+ )
+
+ state_event_ids = result["pdus"]
+ auth_event_ids = result.get("auth_chain", [])
+
+ event_map, _failed_to_fetch = yield self.get_events(
+ [destination], room_id, set(state_event_ids + auth_event_ids)
+ )
+
+ pdus = [event_map[e_id] for e_id in state_event_ids]
+ auth_chain = [event_map[e_id] for e_id in auth_event_ids]
+
+ auth_chain.sort(key=lambda e: e.depth)
+
+ defer.returnValue((pdus, auth_chain))
+ except HttpResponseException as e:
+ if e.code == 404:
+ logger.info("Failed to use get_room_state_ids API, falling back")
+ else:
+ raise e
pdus = [
self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
@@ -340,6 +363,50 @@ class FederationClient(FederationBase):
defer.returnValue((signed_pdus, signed_auth))
@defer.inlineCallbacks
+ def get_events(self, destinations, room_id, event_ids, return_local=True):
+ if return_local:
+ seen_events = yield self.store.get_events(event_ids)
+ signed_events = seen_events.values()
+ else:
+ seen_events = yield self.store.have_events(event_ids)
+ signed_events = []
+
+ failed_to_fetch = []
+
+ missing_events = set(event_ids)
+ for k in seen_events:
+ missing_events.discard(k)
+
+ if not missing_events:
+ defer.returnValue((signed_events, failed_to_fetch))
+
+ def random_server_list():
+ srvs = list(destinations)
+ random.shuffle(srvs)
+ return srvs
+
+ batch_size = 20
+ for i in xrange(0, len(missing_events), batch_size):
+ batch = missing_events[i:i + batch_size]
+
+ deferreds = [
+ self.get_pdu(
+ destinations=random_server_list(),
+ event_id=e_id,
+ ).addBoth(lambda r, e: (r, e), e_id)
+ for e_id in batch
+ ]
+
+ res = yield defer.DeferredList(deferreds, consumeErrors=True)
+ for (result, val), (e_id, _) in res:
+ if result and val:
+ signed_events.append(val)
+ else:
+ failed_to_fetch.add(e_id)
+
+ defer.returnValue((signed_events, failed_to_fetch))
+
+ @defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
res = yield self.transport_layer.get_event_auth(
|