summary refs log tree commit diff
path: root/synapse/federation/federation_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_client.py')
-rw-r--r--synapse/federation/federation_client.py289
1 files changed, 203 insertions, 86 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py

index 83c1f46586..b255709165 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py
@@ -18,7 +18,6 @@ from twisted.internet import defer from .federation_base import FederationBase from synapse.api.constants import Membership -from .units import Edu from synapse.api.errors import ( CodeMessageException, HttpResponseException, SynapseError, @@ -26,7 +25,9 @@ from synapse.api.errors import ( from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.events import FrozenEvent +from synapse.types import get_domain_from_id import synapse.metrics from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination @@ -43,14 +44,37 @@ logger = logging.getLogger(__name__) # synapse.federation.federation_client is a silly name metrics = synapse.metrics.get_metrics_for("synapse.federation.client") -sent_pdus_destination_dist = metrics.register_distribution("sent_pdu_destinations") +sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"]) -sent_edus_counter = metrics.register_counter("sent_edus") -sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"]) +PDU_RETRY_TIME_MS = 1 * 60 * 1000 class FederationClient(FederationBase): + def __init__(self, hs): + super(FederationClient, self).__init__(hs) + + self.pdu_destination_tried = {} + self._clock.looping_call( + self._clear_tried_cache, 60 * 1000, + ) + self.state = hs.get_state_handler() + + def _clear_tried_cache(self): + """Clear pdu_destination_tried cache""" + now = self._clock.time_msec() + + old_dict = self.pdu_destination_tried + self.pdu_destination_tried = {} + + for event_id, destination_dict in old_dict.items(): + destination_dict = { + dest: time + for dest, time in destination_dict.items() + if time + PDU_RETRY_TIME_MS > now + } + if destination_dict: + self.pdu_destination_tried[event_id] = destination_dict def start_get_pdu_cache(self): self._get_pdu_cache = ExpiringCache( @@ -64,55 +88,6 @@ class FederationClient(FederationBase): self._get_pdu_cache.start() @log_function - def send_pdu(self, pdu, destinations): - """Informs the replication layer about a new PDU generated within the - home server that should be transmitted to others. - - TODO: Figure out when we should actually resolve the deferred. - - Args: - pdu (Pdu): The new Pdu. - - Returns: - Deferred: Completes when we have successfully processed the PDU - and replicated it to any interested remote home servers. - """ - order = self._order - self._order += 1 - - sent_pdus_destination_dist.inc_by(len(destinations)) - - logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id) - - # TODO, add errback, etc. - self._transaction_queue.enqueue_pdu(pdu, destinations, order) - - logger.debug( - "[%s] transaction_layer.enqueue_pdu... done", - pdu.event_id - ) - - @log_function - def send_edu(self, destination, edu_type, content): - edu = Edu( - origin=self.server_name, - destination=destination, - edu_type=edu_type, - content=content, - ) - - sent_edus_counter.inc() - - # TODO, add errback, etc. - self._transaction_queue.enqueue_edu(edu) - return defer.succeed(None) - - @log_function - def send_failure(self, failure, destination): - self._transaction_queue.enqueue_failure(failure, destination) - return defer.succeed(None) - - @log_function def make_query(self, destination, query_type, args, retry_on_dns_fail=False): """Sends a federation Query to a remote homeserver of the given type @@ -136,7 +111,7 @@ class FederationClient(FederationBase): ) @log_function - def query_client_keys(self, destination, content): + def query_client_keys(self, destination, content, timeout): """Query device keys for a device hosted on a remote server. Args: @@ -148,10 +123,12 @@ class FederationClient(FederationBase): response """ sent_queries_counter.inc("client_device_keys") - return self.transport_layer.query_client_keys(destination, content) + return self.transport_layer.query_client_keys( + destination, content, timeout + ) @log_function - def claim_client_keys(self, destination, content): + def claim_client_keys(self, destination, content, timeout): """Claims one-time keys for a device hosted on a remote server. Args: @@ -163,7 +140,9 @@ class FederationClient(FederationBase): response """ sent_queries_counter.inc("client_one_time_keys") - return self.transport_layer.claim_client_keys(destination, content) + return self.transport_layer.claim_client_keys( + destination, content, timeout + ) @defer.inlineCallbacks @log_function @@ -198,10 +177,10 @@ class FederationClient(FederationBase): ] # FIXME: We should handle signature failures more gracefully. - pdus[:] = yield defer.gatherResults( + pdus[:] = yield preserve_context_over_deferred(defer.gatherResults( self._check_sigs_and_hashes(pdus), consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) defer.returnValue(pdus) @@ -233,12 +212,19 @@ class FederationClient(FederationBase): # TODO: Rate limit the number of times we try and get the same event. if self._get_pdu_cache: - e = self._get_pdu_cache.get(event_id) - if e: - defer.returnValue(e) + ev = self._get_pdu_cache.get(event_id) + if ev: + defer.returnValue(ev) - pdu = None + pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) + + signed_pdu = None for destination in destinations: + now = self._clock.time_msec() + last_attempt = pdu_attempts.get(destination, 0) + if last_attempt + PDU_RETRY_TIME_MS > now: + continue + try: limiter = yield get_retry_limiter( destination, @@ -262,39 +248,33 @@ class FederationClient(FederationBase): pdu = pdu_list[0] # Check signatures are correct. - pdu = yield self._check_sigs_and_hashes([pdu])[0] + signed_pdu = yield self._check_sigs_and_hashes([pdu])[0] break - except SynapseError: - logger.info( - "Failed to get PDU %s from %s because %s", - event_id, destination, e, - ) - continue - except CodeMessageException as e: - if 400 <= e.code < 500: - raise + pdu_attempts[destination] = now + except SynapseError as e: logger.info( "Failed to get PDU %s from %s because %s", event_id, destination, e, ) - continue except NotRetryingDestination as e: logger.info(e.message) continue except Exception as e: + pdu_attempts[destination] = now + logger.info( "Failed to get PDU %s from %s because %s", event_id, destination, e, ) continue - if self._get_pdu_cache is not None and pdu: - self._get_pdu_cache[event_id] = pdu + if self._get_pdu_cache is not None and signed_pdu: + self._get_pdu_cache[event_id] = signed_pdu - defer.returnValue(pdu) + defer.returnValue(signed_pdu) @defer.inlineCallbacks @log_function @@ -311,6 +291,42 @@ class FederationClient(FederationBase): Deferred: Results in a list of PDUs. """ + try: + # First we try and ask for just the IDs, as thats far quicker if + # we have most of the state and auth_chain already. + # However, this may 404 if the other side has an old synapse. + result = yield self.transport_layer.get_room_state_ids( + destination, room_id, event_id=event_id, + ) + + state_event_ids = result["pdu_ids"] + auth_event_ids = result.get("auth_chain_ids", []) + + fetched_events, failed_to_fetch = yield self.get_events( + [destination], room_id, set(state_event_ids + auth_event_ids) + ) + + if failed_to_fetch: + logger.warn("Failed to get %r", failed_to_fetch) + + event_map = { + ev.event_id: ev for ev in fetched_events + } + + pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] + auth_chain = [ + event_map[e_id] for e_id in auth_event_ids if e_id in event_map + ] + + auth_chain.sort(key=lambda e: e.depth) + + defer.returnValue((pdus, auth_chain)) + except HttpResponseException as e: + if e.code == 400 or e.code == 404: + logger.info("Failed to use get_room_state_ids API, falling back") + else: + raise e + result = yield self.transport_layer.get_room_state( destination, room_id, event_id=event_id, ) @@ -324,12 +340,26 @@ class FederationClient(FederationBase): for p in result.get("auth_chain", []) ] + seen_events = yield self.store.get_events([ + ev.event_id for ev in itertools.chain(pdus, auth_chain) + ]) + signed_pdus = yield self._check_sigs_and_hash_and_fetch( - destination, pdus, outlier=True + destination, + [p for p in pdus if p.event_id not in seen_events], + outlier=True + ) + signed_pdus.extend( + seen_events[p.event_id] for p in pdus if p.event_id in seen_events ) signed_auth = yield self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True + destination, + [p for p in auth_chain if p.event_id not in seen_events], + outlier=True + ) + signed_auth.extend( + seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events ) signed_auth.sort(key=lambda e: e.depth) @@ -337,6 +367,69 @@ class FederationClient(FederationBase): defer.returnValue((signed_pdus, signed_auth)) @defer.inlineCallbacks + def get_events(self, destinations, room_id, event_ids, return_local=True): + """Fetch events from some remote destinations, checking if we already + have them. + + Args: + destinations (list) + room_id (str) + event_ids (list) + return_local (bool): Whether to include events we already have in + the DB in the returned list of events + + Returns: + Deferred: A deferred resolving to a 2-tuple where the first is a list of + events and the second is a list of event ids that we failed to fetch. + """ + if return_local: + seen_events = yield self.store.get_events(event_ids, allow_rejected=True) + signed_events = seen_events.values() + else: + seen_events = yield self.store.have_events(event_ids) + signed_events = [] + + failed_to_fetch = set() + + missing_events = set(event_ids) + for k in seen_events: + missing_events.discard(k) + + if not missing_events: + defer.returnValue((signed_events, failed_to_fetch)) + + def random_server_list(): + srvs = list(destinations) + random.shuffle(srvs) + return srvs + + batch_size = 20 + missing_events = list(missing_events) + for i in xrange(0, len(missing_events), batch_size): + batch = set(missing_events[i:i + batch_size]) + + deferreds = [ + preserve_fn(self.get_pdu)( + destinations=random_server_list(), + event_id=e_id, + ) + for e_id in batch + ] + + res = yield preserve_context_over_deferred( + defer.DeferredList(deferreds, consumeErrors=True) + ) + for success, result in res: + if success and result: + signed_events.append(result) + batch.discard(result.event_id) + + # We removed all events we successfully fetched from `batch` + failed_to_fetch.update(batch) + + defer.returnValue((signed_events, failed_to_fetch)) + + @defer.inlineCallbacks @log_function def get_event_auth(self, destination, room_id, event_id): res = yield self.transport_layer.get_event_auth( @@ -411,8 +504,14 @@ class FederationClient(FederationBase): (destination, self.event_from_pdu_json(pdu_dict)) ) break - except CodeMessageException: - raise + except CodeMessageException as e: + if not 500 <= e.code < 600: + raise + else: + logger.warn( + "Failed to make_%s via %s: %s", + membership, destination, e.message + ) except Exception as e: logger.warn( "Failed to make_%s via %s: %s", @@ -489,8 +588,14 @@ class FederationClient(FederationBase): "auth_chain": signed_auth, "origin": destination, }) - except CodeMessageException: - raise + except CodeMessageException as e: + if not 500 <= e.code < 600: + raise + else: + logger.exception( + "Failed to send_join via %s: %s", + destination, e.message + ) except Exception as e: logger.exception( "Failed to send_join via %s: %s", @@ -549,6 +654,15 @@ class FederationClient(FederationBase): raise RuntimeError("Failed to send to any server.") + def get_public_rooms(self, destination, limit=None, since_token=None, + search_filter=None): + if destination == self.server_name: + return + + return self.transport_layer.get_public_rooms( + destination, limit, since_token, search_filter + ) + @defer.inlineCallbacks def query_auth(self, destination, room_id, event_id, local_auth): """ @@ -638,7 +752,8 @@ class FederationClient(FederationBase): if len(signed_events) >= limit: defer.returnValue(signed_events) - servers = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + servers = set(get_domain_from_id(u) for u in users) servers = set(servers) servers.discard(self.server_name) @@ -683,14 +798,16 @@ class FederationClient(FederationBase): return srvs deferreds = [ - self.get_pdu( + preserve_fn(self.get_pdu)( destinations=random_server_list(), event_id=e_id, ) for e_id, depth in ordered_missing[:limit - len(signed_events)] ] - res = yield defer.DeferredList(deferreds, consumeErrors=True) + res = yield preserve_context_over_deferred( + defer.DeferredList(deferreds, consumeErrors=True) + ) for (result, val), (e_id, _) in zip(res, ordered_missing): if result and val: signed_events.append(val)