summary refs log tree commit diff
path: root/synapse/federation/federation_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_client.py')
-rw-r--r--synapse/federation/federation_client.py209
1 files changed, 189 insertions, 20 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 37ee469fa2..9ba3151713 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -24,6 +24,7 @@ from synapse.api.errors import (
     CodeMessageException, HttpResponseException, SynapseError,
 )
 from synapse.util import unwrapFirstError
+from synapse.util.async import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
 from synapse.events import FrozenEvent
@@ -50,7 +51,33 @@ sent_edus_counter = metrics.register_counter("sent_edus")
 sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
 
 
+PDU_RETRY_TIME_MS = 1 * 60 * 1000
+
+
 class FederationClient(FederationBase):
+    def __init__(self, hs):
+        super(FederationClient, self).__init__(hs)
+
+        self.pdu_destination_tried = {}
+        self._clock.looping_call(
+            self._clear_tried_cache, 60 * 1000,
+        )
+
+    def _clear_tried_cache(self):
+        """Clear pdu_destination_tried cache"""
+        now = self._clock.time_msec()
+
+        old_dict = self.pdu_destination_tried
+        self.pdu_destination_tried = {}
+
+        for event_id, destination_dict in old_dict.items():
+            destination_dict = {
+                dest: time
+                for dest, time in destination_dict.items()
+                if time + PDU_RETRY_TIME_MS > now
+            }
+            if destination_dict:
+                self.pdu_destination_tried[event_id] = destination_dict
 
     def start_get_pdu_cache(self):
         self._get_pdu_cache = ExpiringCache(
@@ -233,12 +260,19 @@ class FederationClient(FederationBase):
         # TODO: Rate limit the number of times we try and get the same event.
 
         if self._get_pdu_cache:
-            e = self._get_pdu_cache.get(event_id)
-            if e:
-                defer.returnValue(e)
+            ev = self._get_pdu_cache.get(event_id)
+            if ev:
+                defer.returnValue(ev)
+
+        pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
 
         pdu = None
         for destination in destinations:
+            now = self._clock.time_msec()
+            last_attempt = pdu_attempts.get(destination, 0)
+            if last_attempt + PDU_RETRY_TIME_MS > now:
+                continue
+
             try:
                 limiter = yield get_retry_limiter(
                     destination,
@@ -266,25 +300,19 @@ class FederationClient(FederationBase):
 
                         break
 
-            except SynapseError:
-                logger.info(
-                    "Failed to get PDU %s from %s because %s",
-                    event_id, destination, e,
-                )
-                continue
-            except CodeMessageException as e:
-                if 400 <= e.code < 500:
-                    raise
+                pdu_attempts[destination] = now
 
+            except SynapseError as e:
                 logger.info(
                     "Failed to get PDU %s from %s because %s",
                     event_id, destination, e,
                 )
-                continue
             except NotRetryingDestination as e:
                 logger.info(e.message)
                 continue
             except Exception as e:
+                pdu_attempts[destination] = now
+
                 logger.info(
                     "Failed to get PDU %s from %s because %s",
                     event_id, destination, e,
@@ -311,6 +339,42 @@ class FederationClient(FederationBase):
             Deferred: Results in a list of PDUs.
         """
 
+        try:
+            # First we try and ask for just the IDs, as thats far quicker if
+            # we have most of the state and auth_chain already.
+            # However, this may 404 if the other side has an old synapse.
+            result = yield self.transport_layer.get_room_state_ids(
+                destination, room_id, event_id=event_id,
+            )
+
+            state_event_ids = result["pdu_ids"]
+            auth_event_ids = result.get("auth_chain_ids", [])
+
+            fetched_events, failed_to_fetch = yield self.get_events(
+                [destination], room_id, set(state_event_ids + auth_event_ids)
+            )
+
+            if failed_to_fetch:
+                logger.warn("Failed to get %r", failed_to_fetch)
+
+            event_map = {
+                ev.event_id: ev for ev in fetched_events
+            }
+
+            pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
+            auth_chain = [
+                event_map[e_id] for e_id in auth_event_ids if e_id in event_map
+            ]
+
+            auth_chain.sort(key=lambda e: e.depth)
+
+            defer.returnValue((pdus, auth_chain))
+        except HttpResponseException as e:
+            if e.code == 400 or e.code == 404:
+                logger.info("Failed to use get_room_state_ids API, falling back")
+            else:
+                raise e
+
         result = yield self.transport_layer.get_room_state(
             destination, room_id, event_id=event_id,
         )
@@ -324,12 +388,26 @@ class FederationClient(FederationBase):
             for p in result.get("auth_chain", [])
         ]
 
+        seen_events = yield self.store.get_events([
+            ev.event_id for ev in itertools.chain(pdus, auth_chain)
+        ])
+
         signed_pdus = yield self._check_sigs_and_hash_and_fetch(
-            destination, pdus, outlier=True
+            destination,
+            [p for p in pdus if p.event_id not in seen_events],
+            outlier=True
+        )
+        signed_pdus.extend(
+            seen_events[p.event_id] for p in pdus if p.event_id in seen_events
         )
 
         signed_auth = yield self._check_sigs_and_hash_and_fetch(
-            destination, auth_chain, outlier=True
+            destination,
+            [p for p in auth_chain if p.event_id not in seen_events],
+            outlier=True
+        )
+        signed_auth.extend(
+            seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
         )
 
         signed_auth.sort(key=lambda e: e.depth)
@@ -337,6 +415,67 @@ class FederationClient(FederationBase):
         defer.returnValue((signed_pdus, signed_auth))
 
     @defer.inlineCallbacks
+    def get_events(self, destinations, room_id, event_ids, return_local=True):
+        """Fetch events from some remote destinations, checking if we already
+        have them.
+
+        Args:
+            destinations (list)
+            room_id (str)
+            event_ids (list)
+            return_local (bool): Whether to include events we already have in
+                the DB in the returned list of events
+
+        Returns:
+            Deferred: A deferred resolving to a 2-tuple where the first is a list of
+            events and the second is a list of event ids that we failed to fetch.
+        """
+        if return_local:
+            seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
+            signed_events = seen_events.values()
+        else:
+            seen_events = yield self.store.have_events(event_ids)
+            signed_events = []
+
+        failed_to_fetch = set()
+
+        missing_events = set(event_ids)
+        for k in seen_events:
+            missing_events.discard(k)
+
+        if not missing_events:
+            defer.returnValue((signed_events, failed_to_fetch))
+
+        def random_server_list():
+            srvs = list(destinations)
+            random.shuffle(srvs)
+            return srvs
+
+        batch_size = 20
+        missing_events = list(missing_events)
+        for i in xrange(0, len(missing_events), batch_size):
+            batch = set(missing_events[i:i + batch_size])
+
+            deferreds = [
+                self.get_pdu(
+                    destinations=random_server_list(),
+                    event_id=e_id,
+                )
+                for e_id in batch
+            ]
+
+            res = yield defer.DeferredList(deferreds, consumeErrors=True)
+            for success, result in res:
+                if success:
+                    signed_events.append(result)
+                    batch.discard(result.event_id)
+
+            # We removed all events we successfully fetched from `batch`
+            failed_to_fetch.update(batch)
+
+        defer.returnValue((signed_events, failed_to_fetch))
+
+    @defer.inlineCallbacks
     @log_function
     def get_event_auth(self, destination, room_id, event_id):
         res = yield self.transport_layer.get_event_auth(
@@ -411,14 +550,19 @@ class FederationClient(FederationBase):
                     (destination, self.event_from_pdu_json(pdu_dict))
                 )
                 break
-            except CodeMessageException:
-                raise
+            except CodeMessageException as e:
+                if not 500 <= e.code < 600:
+                    raise
+                else:
+                    logger.warn(
+                        "Failed to make_%s via %s: %s",
+                        membership, destination, e.message
+                    )
             except Exception as e:
                 logger.warn(
                     "Failed to make_%s via %s: %s",
                     membership, destination, e.message
                 )
-                raise
 
         raise RuntimeError("Failed to send to any server.")
 
@@ -490,8 +634,14 @@ class FederationClient(FederationBase):
                     "auth_chain": signed_auth,
                     "origin": destination,
                 })
-            except CodeMessageException:
-                raise
+            except CodeMessageException as e:
+                if not 500 <= e.code < 600:
+                    raise
+                else:
+                    logger.exception(
+                        "Failed to send_join via %s: %s",
+                        destination, e.message
+                    )
             except Exception as e:
                 logger.exception(
                     "Failed to send_join via %s: %s",
@@ -551,6 +701,25 @@ class FederationClient(FederationBase):
         raise RuntimeError("Failed to send to any server.")
 
     @defer.inlineCallbacks
+    def get_public_rooms(self, destinations):
+        results_by_server = {}
+
+        @defer.inlineCallbacks
+        def _get_result(s):
+            if s == self.server_name:
+                defer.returnValue()
+
+            try:
+                result = yield self.transport_layer.get_public_rooms(s)
+                results_by_server[s] = result
+            except:
+                logger.exception("Error getting room list from server %r", s)
+
+        yield concurrently_execute(_get_result, destinations, 3)
+
+        defer.returnValue(results_by_server)
+
+    @defer.inlineCallbacks
     def query_auth(self, destination, room_id, event_id, local_auth):
         """
         Params: