summary refs log tree commit diff
path: root/synapse/federation/federation_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_client.py')
-rw-r--r--synapse/federation/federation_client.py89
1 files changed, 74 insertions, 15 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d3b46b24c1..f5e346cdbc 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -23,13 +23,14 @@ from synapse.api.errors import (
     CodeMessageException, HttpResponseException, SynapseError,
 )
 from synapse.util import unwrapFirstError
-from synapse.util.expiringcache import ExpiringCache
+from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
 from synapse.events import FrozenEvent
 import synapse.metrics
 
 from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
 
+import copy
 import itertools
 import logging
 import random
@@ -133,6 +134,36 @@ class FederationClient(FederationBase):
             destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
         )
 
+    @log_function
+    def query_client_keys(self, destination, content):
+        """Query device keys for a device hosted on a remote server.
+
+        Args:
+            destination (str): Domain name of the remote homeserver
+            content (dict): The query content.
+
+        Returns:
+            a Deferred which will eventually yield a JSON object from the
+            response
+        """
+        sent_queries_counter.inc("client_device_keys")
+        return self.transport_layer.query_client_keys(destination, content)
+
+    @log_function
+    def claim_client_keys(self, destination, content):
+        """Claims one-time keys for a device hosted on a remote server.
+
+        Args:
+            destination (str): Domain name of the remote homeserver
+            content (dict): The query content.
+
+        Returns:
+            a Deferred which will eventually yield a JSON object from the
+            response
+        """
+        sent_queries_counter.inc("client_one_time_keys")
+        return self.transport_layer.claim_client_keys(destination, content)
+
     @defer.inlineCallbacks
     @log_function
     def backfill(self, dest, context, limit, extremities):
@@ -167,7 +198,7 @@ class FederationClient(FederationBase):
 
         # FIXME: We should handle signature failures more gracefully.
         pdus[:] = yield defer.gatherResults(
-            [self._check_sigs_and_hash(pdu) for pdu in pdus],
+            self._check_sigs_and_hashes(pdus),
             consumeErrors=True,
         ).addErrback(unwrapFirstError)
 
@@ -230,7 +261,7 @@ class FederationClient(FederationBase):
                         pdu = pdu_list[0]
 
                         # Check signatures are correct.
-                        pdu = yield self._check_sigs_and_hash(pdu)
+                        pdu = yield self._check_sigs_and_hashes([pdu])[0]
 
                         break
 
@@ -327,6 +358,9 @@ class FederationClient(FederationBase):
     @defer.inlineCallbacks
     def make_join(self, destinations, room_id, user_id):
         for destination in destinations:
+            if destination == self.server_name:
+                continue
+
             try:
                 ret = yield self.transport_layer.make_join(
                     destination, room_id, user_id
@@ -353,6 +387,9 @@ class FederationClient(FederationBase):
     @defer.inlineCallbacks
     def send_join(self, destinations, pdu):
         for destination in destinations:
+            if destination == self.server_name:
+                continue
+
             try:
                 time_now = self._clock.time_msec()
                 _, content = yield self.transport_layer.send_join(
@@ -374,17 +411,39 @@ class FederationClient(FederationBase):
                     for p in content.get("auth_chain", [])
                 ]
 
-                signed_state, signed_auth = yield defer.gatherResults(
-                    [
-                        self._check_sigs_and_hash_and_fetch(
-                            destination, state, outlier=True
-                        ),
-                        self._check_sigs_and_hash_and_fetch(
-                            destination, auth_chain, outlier=True
-                        )
-                    ],
-                    consumeErrors=True
-                ).addErrback(unwrapFirstError)
+                pdus = {
+                    p.event_id: p
+                    for p in itertools.chain(state, auth_chain)
+                }
+
+                valid_pdus = yield self._check_sigs_and_hash_and_fetch(
+                    destination, pdus.values(),
+                    outlier=True,
+                )
+
+                valid_pdus_map = {
+                    p.event_id: p
+                    for p in valid_pdus
+                }
+
+                # NB: We *need* to copy to ensure that we don't have multiple
+                # references being passed on, as that causes... issues.
+                signed_state = [
+                    copy.copy(valid_pdus_map[p.event_id])
+                    for p in state
+                    if p.event_id in valid_pdus_map
+                ]
+
+                signed_auth = [
+                    valid_pdus_map[p.event_id]
+                    for p in auth_chain
+                    if p.event_id in valid_pdus_map
+                ]
+
+                # NB: We *need* to copy to ensure that we don't have multiple
+                # references being passed on, as that causes... issues.
+                for s in signed_state:
+                    s.internal_metadata = copy.deepcopy(s.internal_metadata)
 
                 auth_chain.sort(key=lambda e: e.depth)
 
@@ -396,7 +455,7 @@ class FederationClient(FederationBase):
             except CodeMessageException:
                 raise
             except Exception as e:
-                logger.warn(
+                logger.exception(
                     "Failed to send_join via %s: %s",
                     destination, e.message
                 )