diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d3b46b24c1..7736d14fb5 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -30,6 +30,7 @@ import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
+import copy
import itertools
import logging
import random
@@ -167,7 +168,7 @@ class FederationClient(FederationBase):
# FIXME: We should handle signature failures more gracefully.
pdus[:] = yield defer.gatherResults(
- [self._check_sigs_and_hash(pdu) for pdu in pdus],
+ self._check_sigs_and_hashes(pdus),
consumeErrors=True,
).addErrback(unwrapFirstError)
@@ -230,7 +231,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
- pdu = yield self._check_sigs_and_hash(pdu)
+ pdu = yield self._check_sigs_and_hashes([pdu])[0]
break
@@ -327,6 +328,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def make_join(self, destinations, room_id, user_id):
for destination in destinations:
+ if destination == self.server_name:
+ continue
+
try:
ret = yield self.transport_layer.make_join(
destination, room_id, user_id
@@ -353,6 +357,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_join(self, destinations, pdu):
for destination in destinations:
+ if destination == self.server_name:
+ continue
+
try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
@@ -374,17 +381,39 @@ class FederationClient(FederationBase):
for p in content.get("auth_chain", [])
]
- signed_state, signed_auth = yield defer.gatherResults(
- [
- self._check_sigs_and_hash_and_fetch(
- destination, state, outlier=True
- ),
- self._check_sigs_and_hash_and_fetch(
- destination, auth_chain, outlier=True
- )
- ],
- consumeErrors=True
- ).addErrback(unwrapFirstError)
+ pdus = {
+ p.event_id: p
+ for p in itertools.chain(state, auth_chain)
+ }
+
+ valid_pdus = yield self._check_sigs_and_hash_and_fetch(
+ destination, pdus.values(),
+ outlier=True,
+ )
+
+ valid_pdus_map = {
+ p.event_id: p
+ for p in valid_pdus
+ }
+
+ # NB: We *need* to copy to ensure that we don't have multiple
+ # references being passed on, as that causes... issues.
+ signed_state = [
+ copy.copy(valid_pdus_map[p.event_id])
+ for p in state
+ if p.event_id in valid_pdus_map
+ ]
+
+ signed_auth = [
+ valid_pdus_map[p.event_id]
+ for p in auth_chain
+ if p.event_id in valid_pdus_map
+ ]
+
+ # NB: We *need* to copy to ensure that we don't have multiple
+ # references being passed on, as that causes... issues.
+ for s in signed_state:
+ s.internal_metadata = copy.deepcopy(s.internal_metadata)
auth_chain.sort(key=lambda e: e.depth)
@@ -396,7 +425,7 @@ class FederationClient(FederationBase):
except CodeMessageException:
raise
except Exception as e:
- logger.warn(
+ logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
|