diff options
Diffstat (limited to 'synapse/federation/federation_client.py')
-rw-r--r-- | synapse/federation/federation_client.py | 89 |
1 files changed, 74 insertions, 15 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index d3b46b24c1..f5e346cdbc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -23,13 +23,14 @@ from synapse.api.errors import ( CodeMessageException, HttpResponseException, SynapseError, ) from synapse.util import unwrapFirstError -from synapse.util.expiringcache import ExpiringCache +from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.events import FrozenEvent import synapse.metrics from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination +import copy import itertools import logging import random @@ -133,6 +134,36 @@ class FederationClient(FederationBase): destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail ) + @log_function + def query_client_keys(self, destination, content): + """Query device keys for a device hosted on a remote server. + + Args: + destination (str): Domain name of the remote homeserver + content (dict): The query content. + + Returns: + a Deferred which will eventually yield a JSON object from the + response + """ + sent_queries_counter.inc("client_device_keys") + return self.transport_layer.query_client_keys(destination, content) + + @log_function + def claim_client_keys(self, destination, content): + """Claims one-time keys for a device hosted on a remote server. + + Args: + destination (str): Domain name of the remote homeserver + content (dict): The query content. + + Returns: + a Deferred which will eventually yield a JSON object from the + response + """ + sent_queries_counter.inc("client_one_time_keys") + return self.transport_layer.claim_client_keys(destination, content) + @defer.inlineCallbacks @log_function def backfill(self, dest, context, limit, extremities): @@ -167,7 +198,7 @@ class FederationClient(FederationBase): # FIXME: We should handle signature failures more gracefully. pdus[:] = yield defer.gatherResults( - [self._check_sigs_and_hash(pdu) for pdu in pdus], + self._check_sigs_and_hashes(pdus), consumeErrors=True, ).addErrback(unwrapFirstError) @@ -230,7 +261,7 @@ class FederationClient(FederationBase): pdu = pdu_list[0] # Check signatures are correct. - pdu = yield self._check_sigs_and_hash(pdu) + pdu = yield self._check_sigs_and_hashes([pdu])[0] break @@ -327,6 +358,9 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def make_join(self, destinations, room_id, user_id): for destination in destinations: + if destination == self.server_name: + continue + try: ret = yield self.transport_layer.make_join( destination, room_id, user_id @@ -353,6 +387,9 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_join(self, destinations, pdu): for destination in destinations: + if destination == self.server_name: + continue + try: time_now = self._clock.time_msec() _, content = yield self.transport_layer.send_join( @@ -374,17 +411,39 @@ class FederationClient(FederationBase): for p in content.get("auth_chain", []) ] - signed_state, signed_auth = yield defer.gatherResults( - [ - self._check_sigs_and_hash_and_fetch( - destination, state, outlier=True - ), - self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True - ) - ], - consumeErrors=True - ).addErrback(unwrapFirstError) + pdus = { + p.event_id: p + for p in itertools.chain(state, auth_chain) + } + + valid_pdus = yield self._check_sigs_and_hash_and_fetch( + destination, pdus.values(), + outlier=True, + ) + + valid_pdus_map = { + p.event_id: p + for p in valid_pdus + } + + # NB: We *need* to copy to ensure that we don't have multiple + # references being passed on, as that causes... issues. + signed_state = [ + copy.copy(valid_pdus_map[p.event_id]) + for p in state + if p.event_id in valid_pdus_map + ] + + signed_auth = [ + valid_pdus_map[p.event_id] + for p in auth_chain + if p.event_id in valid_pdus_map + ] + + # NB: We *need* to copy to ensure that we don't have multiple + # references being passed on, as that causes... issues. + for s in signed_state: + s.internal_metadata = copy.deepcopy(s.internal_metadata) auth_chain.sort(key=lambda e: e.depth) @@ -396,7 +455,7 @@ class FederationClient(FederationBase): except CodeMessageException: raise except Exception as e: - logger.warn( + logger.exception( "Failed to send_join via %s: %s", destination, e.message ) |