diff options
author | David Baker <dave@matrix.org> | 2016-08-11 14:09:13 +0100 |
---|---|---|
committer | David Baker <dave@matrix.org> | 2016-08-11 14:09:13 +0100 |
commit | b4ecf0b886c67437901e0af457c5f801ebde9a72 (patch) | |
tree | ef66b0684edcfeb4ad68d20375641f4654393f44 /synapse/federation/federation_client.py | |
parent | Include the ts the notif was received at (diff) | |
parent | Merge pull request #1003 from matrix-org/erikj/redaction_prev_content (diff) | |
download | synapse-b4ecf0b886c67437901e0af457c5f801ebde9a72.tar.xz |
Merge remote-tracking branch 'origin/develop' into dbkr/notifications_api
Diffstat (limited to 'synapse/federation/federation_client.py')
-rw-r--r-- | synapse/federation/federation_client.py | 209 |
1 files changed, 189 insertions, 20 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 37ee469fa2..9ba3151713 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -24,6 +24,7 @@ from synapse.api.errors import ( CodeMessageException, HttpResponseException, SynapseError, ) from synapse.util import unwrapFirstError +from synapse.util.async import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.events import FrozenEvent @@ -50,7 +51,33 @@ sent_edus_counter = metrics.register_counter("sent_edus") sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"]) +PDU_RETRY_TIME_MS = 1 * 60 * 1000 + + class FederationClient(FederationBase): + def __init__(self, hs): + super(FederationClient, self).__init__(hs) + + self.pdu_destination_tried = {} + self._clock.looping_call( + self._clear_tried_cache, 60 * 1000, + ) + + def _clear_tried_cache(self): + """Clear pdu_destination_tried cache""" + now = self._clock.time_msec() + + old_dict = self.pdu_destination_tried + self.pdu_destination_tried = {} + + for event_id, destination_dict in old_dict.items(): + destination_dict = { + dest: time + for dest, time in destination_dict.items() + if time + PDU_RETRY_TIME_MS > now + } + if destination_dict: + self.pdu_destination_tried[event_id] = destination_dict def start_get_pdu_cache(self): self._get_pdu_cache = ExpiringCache( @@ -233,12 +260,19 @@ class FederationClient(FederationBase): # TODO: Rate limit the number of times we try and get the same event. if self._get_pdu_cache: - e = self._get_pdu_cache.get(event_id) - if e: - defer.returnValue(e) + ev = self._get_pdu_cache.get(event_id) + if ev: + defer.returnValue(ev) + + pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) pdu = None for destination in destinations: + now = self._clock.time_msec() + last_attempt = pdu_attempts.get(destination, 0) + if last_attempt + PDU_RETRY_TIME_MS > now: + continue + try: limiter = yield get_retry_limiter( destination, @@ -266,25 +300,19 @@ class FederationClient(FederationBase): break - except SynapseError: - logger.info( - "Failed to get PDU %s from %s because %s", - event_id, destination, e, - ) - continue - except CodeMessageException as e: - if 400 <= e.code < 500: - raise + pdu_attempts[destination] = now + except SynapseError as e: logger.info( "Failed to get PDU %s from %s because %s", event_id, destination, e, ) - continue except NotRetryingDestination as e: logger.info(e.message) continue except Exception as e: + pdu_attempts[destination] = now + logger.info( "Failed to get PDU %s from %s because %s", event_id, destination, e, @@ -311,6 +339,42 @@ class FederationClient(FederationBase): Deferred: Results in a list of PDUs. """ + try: + # First we try and ask for just the IDs, as thats far quicker if + # we have most of the state and auth_chain already. + # However, this may 404 if the other side has an old synapse. + result = yield self.transport_layer.get_room_state_ids( + destination, room_id, event_id=event_id, + ) + + state_event_ids = result["pdu_ids"] + auth_event_ids = result.get("auth_chain_ids", []) + + fetched_events, failed_to_fetch = yield self.get_events( + [destination], room_id, set(state_event_ids + auth_event_ids) + ) + + if failed_to_fetch: + logger.warn("Failed to get %r", failed_to_fetch) + + event_map = { + ev.event_id: ev for ev in fetched_events + } + + pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] + auth_chain = [ + event_map[e_id] for e_id in auth_event_ids if e_id in event_map + ] + + auth_chain.sort(key=lambda e: e.depth) + + defer.returnValue((pdus, auth_chain)) + except HttpResponseException as e: + if e.code == 400 or e.code == 404: + logger.info("Failed to use get_room_state_ids API, falling back") + else: + raise e + result = yield self.transport_layer.get_room_state( destination, room_id, event_id=event_id, ) @@ -324,12 +388,26 @@ class FederationClient(FederationBase): for p in result.get("auth_chain", []) ] + seen_events = yield self.store.get_events([ + ev.event_id for ev in itertools.chain(pdus, auth_chain) + ]) + signed_pdus = yield self._check_sigs_and_hash_and_fetch( - destination, pdus, outlier=True + destination, + [p for p in pdus if p.event_id not in seen_events], + outlier=True + ) + signed_pdus.extend( + seen_events[p.event_id] for p in pdus if p.event_id in seen_events ) signed_auth = yield self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True + destination, + [p for p in auth_chain if p.event_id not in seen_events], + outlier=True + ) + signed_auth.extend( + seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events ) signed_auth.sort(key=lambda e: e.depth) @@ -337,6 +415,67 @@ class FederationClient(FederationBase): defer.returnValue((signed_pdus, signed_auth)) @defer.inlineCallbacks + def get_events(self, destinations, room_id, event_ids, return_local=True): + """Fetch events from some remote destinations, checking if we already + have them. + + Args: + destinations (list) + room_id (str) + event_ids (list) + return_local (bool): Whether to include events we already have in + the DB in the returned list of events + + Returns: + Deferred: A deferred resolving to a 2-tuple where the first is a list of + events and the second is a list of event ids that we failed to fetch. + """ + if return_local: + seen_events = yield self.store.get_events(event_ids, allow_rejected=True) + signed_events = seen_events.values() + else: + seen_events = yield self.store.have_events(event_ids) + signed_events = [] + + failed_to_fetch = set() + + missing_events = set(event_ids) + for k in seen_events: + missing_events.discard(k) + + if not missing_events: + defer.returnValue((signed_events, failed_to_fetch)) + + def random_server_list(): + srvs = list(destinations) + random.shuffle(srvs) + return srvs + + batch_size = 20 + missing_events = list(missing_events) + for i in xrange(0, len(missing_events), batch_size): + batch = set(missing_events[i:i + batch_size]) + + deferreds = [ + self.get_pdu( + destinations=random_server_list(), + event_id=e_id, + ) + for e_id in batch + ] + + res = yield defer.DeferredList(deferreds, consumeErrors=True) + for success, result in res: + if success: + signed_events.append(result) + batch.discard(result.event_id) + + # We removed all events we successfully fetched from `batch` + failed_to_fetch.update(batch) + + defer.returnValue((signed_events, failed_to_fetch)) + + @defer.inlineCallbacks @log_function def get_event_auth(self, destination, room_id, event_id): res = yield self.transport_layer.get_event_auth( @@ -411,14 +550,19 @@ class FederationClient(FederationBase): (destination, self.event_from_pdu_json(pdu_dict)) ) break - except CodeMessageException: - raise + except CodeMessageException as e: + if not 500 <= e.code < 600: + raise + else: + logger.warn( + "Failed to make_%s via %s: %s", + membership, destination, e.message + ) except Exception as e: logger.warn( "Failed to make_%s via %s: %s", membership, destination, e.message ) - raise raise RuntimeError("Failed to send to any server.") @@ -490,8 +634,14 @@ class FederationClient(FederationBase): "auth_chain": signed_auth, "origin": destination, }) - except CodeMessageException: - raise + except CodeMessageException as e: + if not 500 <= e.code < 600: + raise + else: + logger.exception( + "Failed to send_join via %s: %s", + destination, e.message + ) except Exception as e: logger.exception( "Failed to send_join via %s: %s", @@ -551,6 +701,25 @@ class FederationClient(FederationBase): raise RuntimeError("Failed to send to any server.") @defer.inlineCallbacks + def get_public_rooms(self, destinations): + results_by_server = {} + + @defer.inlineCallbacks + def _get_result(s): + if s == self.server_name: + defer.returnValue() + + try: + result = yield self.transport_layer.get_public_rooms(s) + results_by_server[s] = result + except: + logger.exception("Error getting room list from server %r", s) + + yield concurrently_execute(_get_result, destinations, 3) + + defer.returnValue(results_by_server) + + @defer.inlineCallbacks def query_auth(self, destination, room_id, event_id, local_auth): """ Params: |