diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 83c1f46586..b255709165 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,7 +18,6 @@ from twisted.internet import defer
from .federation_base import FederationBase
from synapse.api.constants import Membership
-from .units import Edu
from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
@@ -26,7 +25,9 @@ from synapse.api.errors import (
from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent
+from synapse.types import get_domain_from_id
import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@@ -43,14 +44,37 @@ logger = logging.getLogger(__name__)
# synapse.federation.federation_client is a silly name
metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
-sent_pdus_destination_dist = metrics.register_distribution("sent_pdu_destinations")
+sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
-sent_edus_counter = metrics.register_counter("sent_edus")
-sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
+PDU_RETRY_TIME_MS = 1 * 60 * 1000
class FederationClient(FederationBase):
+ def __init__(self, hs):
+ super(FederationClient, self).__init__(hs)
+
+ self.pdu_destination_tried = {}
+ self._clock.looping_call(
+ self._clear_tried_cache, 60 * 1000,
+ )
+ self.state = hs.get_state_handler()
+
+ def _clear_tried_cache(self):
+ """Clear pdu_destination_tried cache"""
+ now = self._clock.time_msec()
+
+ old_dict = self.pdu_destination_tried
+ self.pdu_destination_tried = {}
+
+ for event_id, destination_dict in old_dict.items():
+ destination_dict = {
+ dest: time
+ for dest, time in destination_dict.items()
+ if time + PDU_RETRY_TIME_MS > now
+ }
+ if destination_dict:
+ self.pdu_destination_tried[event_id] = destination_dict
def start_get_pdu_cache(self):
self._get_pdu_cache = ExpiringCache(
@@ -64,55 +88,6 @@ class FederationClient(FederationBase):
self._get_pdu_cache.start()
@log_function
- def send_pdu(self, pdu, destinations):
- """Informs the replication layer about a new PDU generated within the
- home server that should be transmitted to others.
-
- TODO: Figure out when we should actually resolve the deferred.
-
- Args:
- pdu (Pdu): The new Pdu.
-
- Returns:
- Deferred: Completes when we have successfully processed the PDU
- and replicated it to any interested remote home servers.
- """
- order = self._order
- self._order += 1
-
- sent_pdus_destination_dist.inc_by(len(destinations))
-
- logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
-
- # TODO, add errback, etc.
- self._transaction_queue.enqueue_pdu(pdu, destinations, order)
-
- logger.debug(
- "[%s] transaction_layer.enqueue_pdu... done",
- pdu.event_id
- )
-
- @log_function
- def send_edu(self, destination, edu_type, content):
- edu = Edu(
- origin=self.server_name,
- destination=destination,
- edu_type=edu_type,
- content=content,
- )
-
- sent_edus_counter.inc()
-
- # TODO, add errback, etc.
- self._transaction_queue.enqueue_edu(edu)
- return defer.succeed(None)
-
- @log_function
- def send_failure(self, failure, destination):
- self._transaction_queue.enqueue_failure(failure, destination)
- return defer.succeed(None)
-
- @log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=False):
"""Sends a federation Query to a remote homeserver of the given type
@@ -136,7 +111,7 @@ class FederationClient(FederationBase):
)
@log_function
- def query_client_keys(self, destination, content):
+ def query_client_keys(self, destination, content, timeout):
"""Query device keys for a device hosted on a remote server.
Args:
@@ -148,10 +123,12 @@ class FederationClient(FederationBase):
response
"""
sent_queries_counter.inc("client_device_keys")
- return self.transport_layer.query_client_keys(destination, content)
+ return self.transport_layer.query_client_keys(
+ destination, content, timeout
+ )
@log_function
- def claim_client_keys(self, destination, content):
+ def claim_client_keys(self, destination, content, timeout):
"""Claims one-time keys for a device hosted on a remote server.
Args:
@@ -163,7 +140,9 @@ class FederationClient(FederationBase):
response
"""
sent_queries_counter.inc("client_one_time_keys")
- return self.transport_layer.claim_client_keys(destination, content)
+ return self.transport_layer.claim_client_keys(
+ destination, content, timeout
+ )
@defer.inlineCallbacks
@log_function
@@ -198,10 +177,10 @@ class FederationClient(FederationBase):
]
# FIXME: We should handle signature failures more gracefully.
- pdus[:] = yield defer.gatherResults(
+ pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
self._check_sigs_and_hashes(pdus),
consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
defer.returnValue(pdus)
@@ -233,12 +212,19 @@ class FederationClient(FederationBase):
# TODO: Rate limit the number of times we try and get the same event.
if self._get_pdu_cache:
- e = self._get_pdu_cache.get(event_id)
- if e:
- defer.returnValue(e)
+ ev = self._get_pdu_cache.get(event_id)
+ if ev:
+ defer.returnValue(ev)
- pdu = None
+ pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
+
+ signed_pdu = None
for destination in destinations:
+ now = self._clock.time_msec()
+ last_attempt = pdu_attempts.get(destination, 0)
+ if last_attempt + PDU_RETRY_TIME_MS > now:
+ continue
+
try:
limiter = yield get_retry_limiter(
destination,
@@ -262,39 +248,33 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
- pdu = yield self._check_sigs_and_hashes([pdu])[0]
+ signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
break
- except SynapseError:
- logger.info(
- "Failed to get PDU %s from %s because %s",
- event_id, destination, e,
- )
- continue
- except CodeMessageException as e:
- if 400 <= e.code < 500:
- raise
+ pdu_attempts[destination] = now
+ except SynapseError as e:
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
- continue
except NotRetryingDestination as e:
logger.info(e.message)
continue
except Exception as e:
+ pdu_attempts[destination] = now
+
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
continue
- if self._get_pdu_cache is not None and pdu:
- self._get_pdu_cache[event_id] = pdu
+ if self._get_pdu_cache is not None and signed_pdu:
+ self._get_pdu_cache[event_id] = signed_pdu
- defer.returnValue(pdu)
+ defer.returnValue(signed_pdu)
@defer.inlineCallbacks
@log_function
@@ -311,6 +291,42 @@ class FederationClient(FederationBase):
Deferred: Results in a list of PDUs.
"""
+ try:
+ # First we try and ask for just the IDs, as thats far quicker if
+ # we have most of the state and auth_chain already.
+ # However, this may 404 if the other side has an old synapse.
+ result = yield self.transport_layer.get_room_state_ids(
+ destination, room_id, event_id=event_id,
+ )
+
+ state_event_ids = result["pdu_ids"]
+ auth_event_ids = result.get("auth_chain_ids", [])
+
+ fetched_events, failed_to_fetch = yield self.get_events(
+ [destination], room_id, set(state_event_ids + auth_event_ids)
+ )
+
+ if failed_to_fetch:
+ logger.warn("Failed to get %r", failed_to_fetch)
+
+ event_map = {
+ ev.event_id: ev for ev in fetched_events
+ }
+
+ pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
+ auth_chain = [
+ event_map[e_id] for e_id in auth_event_ids if e_id in event_map
+ ]
+
+ auth_chain.sort(key=lambda e: e.depth)
+
+ defer.returnValue((pdus, auth_chain))
+ except HttpResponseException as e:
+ if e.code == 400 or e.code == 404:
+ logger.info("Failed to use get_room_state_ids API, falling back")
+ else:
+ raise e
+
result = yield self.transport_layer.get_room_state(
destination, room_id, event_id=event_id,
)
@@ -324,12 +340,26 @@ class FederationClient(FederationBase):
for p in result.get("auth_chain", [])
]
+ seen_events = yield self.store.get_events([
+ ev.event_id for ev in itertools.chain(pdus, auth_chain)
+ ])
+
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
- destination, pdus, outlier=True
+ destination,
+ [p for p in pdus if p.event_id not in seen_events],
+ outlier=True
+ )
+ signed_pdus.extend(
+ seen_events[p.event_id] for p in pdus if p.event_id in seen_events
)
signed_auth = yield self._check_sigs_and_hash_and_fetch(
- destination, auth_chain, outlier=True
+ destination,
+ [p for p in auth_chain if p.event_id not in seen_events],
+ outlier=True
+ )
+ signed_auth.extend(
+ seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
)
signed_auth.sort(key=lambda e: e.depth)
@@ -337,6 +367,69 @@ class FederationClient(FederationBase):
defer.returnValue((signed_pdus, signed_auth))
@defer.inlineCallbacks
+ def get_events(self, destinations, room_id, event_ids, return_local=True):
+ """Fetch events from some remote destinations, checking if we already
+ have them.
+
+ Args:
+ destinations (list)
+ room_id (str)
+ event_ids (list)
+ return_local (bool): Whether to include events we already have in
+ the DB in the returned list of events
+
+ Returns:
+ Deferred: A deferred resolving to a 2-tuple where the first is a list of
+ events and the second is a list of event ids that we failed to fetch.
+ """
+ if return_local:
+ seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
+ signed_events = seen_events.values()
+ else:
+ seen_events = yield self.store.have_events(event_ids)
+ signed_events = []
+
+ failed_to_fetch = set()
+
+ missing_events = set(event_ids)
+ for k in seen_events:
+ missing_events.discard(k)
+
+ if not missing_events:
+ defer.returnValue((signed_events, failed_to_fetch))
+
+ def random_server_list():
+ srvs = list(destinations)
+ random.shuffle(srvs)
+ return srvs
+
+ batch_size = 20
+ missing_events = list(missing_events)
+ for i in xrange(0, len(missing_events), batch_size):
+ batch = set(missing_events[i:i + batch_size])
+
+ deferreds = [
+ preserve_fn(self.get_pdu)(
+ destinations=random_server_list(),
+ event_id=e_id,
+ )
+ for e_id in batch
+ ]
+
+ res = yield preserve_context_over_deferred(
+ defer.DeferredList(deferreds, consumeErrors=True)
+ )
+ for success, result in res:
+ if success and result:
+ signed_events.append(result)
+ batch.discard(result.event_id)
+
+ # We removed all events we successfully fetched from `batch`
+ failed_to_fetch.update(batch)
+
+ defer.returnValue((signed_events, failed_to_fetch))
+
+ @defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
res = yield self.transport_layer.get_event_auth(
@@ -411,8 +504,14 @@ class FederationClient(FederationBase):
(destination, self.event_from_pdu_json(pdu_dict))
)
break
- except CodeMessageException:
- raise
+ except CodeMessageException as e:
+ if not 500 <= e.code < 600:
+ raise
+ else:
+ logger.warn(
+ "Failed to make_%s via %s: %s",
+ membership, destination, e.message
+ )
except Exception as e:
logger.warn(
"Failed to make_%s via %s: %s",
@@ -489,8 +588,14 @@ class FederationClient(FederationBase):
"auth_chain": signed_auth,
"origin": destination,
})
- except CodeMessageException:
- raise
+ except CodeMessageException as e:
+ if not 500 <= e.code < 600:
+ raise
+ else:
+ logger.exception(
+ "Failed to send_join via %s: %s",
+ destination, e.message
+ )
except Exception as e:
logger.exception(
"Failed to send_join via %s: %s",
@@ -549,6 +654,15 @@ class FederationClient(FederationBase):
raise RuntimeError("Failed to send to any server.")
+ def get_public_rooms(self, destination, limit=None, since_token=None,
+ search_filter=None):
+ if destination == self.server_name:
+ return
+
+ return self.transport_layer.get_public_rooms(
+ destination, limit, since_token, search_filter
+ )
+
@defer.inlineCallbacks
def query_auth(self, destination, room_id, event_id, local_auth):
"""
@@ -638,7 +752,8 @@ class FederationClient(FederationBase):
if len(signed_events) >= limit:
defer.returnValue(signed_events)
- servers = yield self.store.get_joined_hosts_for_room(room_id)
+ users = yield self.state.get_current_user_in_room(room_id)
+ servers = set(get_domain_from_id(u) for u in users)
servers = set(servers)
servers.discard(self.server_name)
@@ -683,14 +798,16 @@ class FederationClient(FederationBase):
return srvs
deferreds = [
- self.get_pdu(
+ preserve_fn(self.get_pdu)(
destinations=random_server_list(),
event_id=e_id,
)
for e_id, depth in ordered_missing[:limit - len(signed_events)]
]
- res = yield defer.DeferredList(deferreds, consumeErrors=True)
+ res = yield preserve_context_over_deferred(
+ defer.DeferredList(deferreds, consumeErrors=True)
+ )
for (result, val), (e_id, _) in zip(res, ordered_missing):
if result and val:
signed_events.append(val)
|