diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d42930d1b9..688d43fffb 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -79,7 +79,7 @@ class InvalidResponseError(RuntimeError):
class FederationClient(FederationBase):
def __init__(self, hs):
- super(FederationClient, self).__init__(hs)
+ super().__init__(hs)
self.pdu_destination_tried = {}
self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ff00f0b302..2dcd081cbc 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -90,7 +90,7 @@ pdu_process_time = Histogram(
class FederationServer(FederationBase):
def __init__(self, hs):
- super(FederationServer, self).__init__(hs)
+ super().__init__(hs)
self.auth = hs.get_auth()
self.handler = hs.get_handlers().federation_handler
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 552519e82c..8bb17b3a05 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -55,6 +55,15 @@ sent_pdus_destination_dist_total = Counter(
"Total number of PDUs queued for sending across all destinations",
)
+# Time (in s) after Synapse's startup that we will begin to wake up destinations
+# that have catch-up outstanding.
+CATCH_UP_STARTUP_DELAY_SEC = 15
+
+# Time (in s) to wait in between waking up each destination, i.e. one destination
+# will be woken up every <x> seconds after Synapse's startup until we have woken
+# every destination that has outstanding catch-up.
+CATCH_UP_STARTUP_INTERVAL_SEC = 5
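+# (With the defaults above, the first lagging destination is woken 15s after
+# startup, the second at about 20s, the third at 25s, and so on.)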
+
class FederationSender:
def __init__(self, hs: "synapse.server.HomeServer"):
@@ -125,6 +134,14 @@ class FederationSender:
1000.0 / hs.config.federation_rr_transactions_per_room_per_second
)
+ # wake up destinations that have outstanding PDUs to be caught up
+ self._catchup_after_startup_timer = self.clock.call_later(
+ CATCH_UP_STARTUP_DELAY_SEC,
+ run_as_background_process,
+ "wake_destinations_needing_catchup",
+ self._wake_destinations_needing_catchup,
+ )
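+ # (`_wake_destinations_needing_catchup` clears this timer handle once
+ # every lagging destination has been woken)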
+
def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
"""Get or create a PerDestinationQueue for the given destination
@@ -209,7 +226,7 @@ class FederationSender:
logger.debug("Sending %s to %r", event, destinations)
if destinations:
- self._send_pdu(event, destinations)
+ await self._send_pdu(event, destinations)
now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)
@@ -265,7 +282,7 @@ class FederationSender:
finally:
self._is_processing = False
- def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
+ async def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
@@ -280,6 +297,13 @@ class FederationSender:
sent_pdus_destination_dist_total.inc(len(destinations))
sent_pdus_destination_dist_count.inc()
+ # track the fact that we have a PDU for these destinations,
+ # to allow us to perform catch-up later on if the remote is unreachable
+ # for a while.
+ await self.store.store_destination_rooms_entries(
+ destinations, pdu.room_id, pdu.internal_metadata.stream_ordering,
+ )
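+ # (these rows are what `get_catch_up_room_event_ids` later consults to
+ # work out which rooms need catching up for a given destination)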
+
for destination in destinations:
self._get_per_destination_queue(destination).send_pdu(pdu)
@@ -553,3 +577,37 @@ class FederationSender:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return [], 0, False
+
+ async def _wake_destinations_needing_catchup(self) -> None:
+ """
+ Wakes up destinations that need catch-up and are not currently being
+ backed off from.
+
+ In order to reduce load spikes, this adds a delay between waking each
+ destination.
+ """
+
+ last_processed = None # type: Optional[str]
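+ # (this doubles as the pagination cursor: each database query resumes
+ # after the last destination examined in the previous page)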
+
+ while True:
+ destinations_to_wake = await self.store.get_catch_up_outstanding_destinations(
+ last_processed
+ )
+
+ if not destinations_to_wake:
+ # finished waking all destinations!
+ self._catchup_after_startup_timer = None
+ break
+
+ for destination in destinations_to_wake:
+ # advance the pagination cursor even when this destination is
+ # handled by another federation sender instance; otherwise a
+ # page consisting entirely of such destinations would leave
+ # `last_processed` unchanged and spin this loop forever on the
+ # same query.
+ last_processed = destination
+
+ if not self._federation_shard_config.should_handle(
+ self._instance_name, destination
+ ):
+ continue
+
+ logger.info(
+ "Destination %s has outstanding catch-up, waking up.",
+ destination,
+ )
+ self.wake_destination(destination)
+ await self.clock.sleep(CATCH_UP_STARTUP_INTERVAL_SEC)
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index defc228c23..2657767fd1 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,7 +15,7 @@
# limitations under the License.
import datetime
import logging
-from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast
from prometheus_client import Counter
@@ -92,6 +92,21 @@ class PerDestinationQueue:
self._destination = destination
self.transmission_loop_running = False
+ # True whilst we are sending events that the remote homeserver missed
+ # because it was unreachable. We start in this state so we can perform
+ # catch-up at startup.
+ # New events will only be sent once this is finished, at which point
+ # _catching_up is flipped to False.
+ self._catching_up = True # type: bool
+
+ # The stream_ordering of the most recent PDU that was discarded due to
+ # being in catch-up mode.
+ self._catchup_last_skipped = 0 # type: int
+
+ # Cache of the last successfully-transmitted stream ordering for this
+ # destination (we are the only updater so this is safe)
+ self._last_successful_stream_ordering = None # type: Optional[int]
+
# a list of pending PDUs
self._pending_pdus = [] # type: List[EventBase]
@@ -138,7 +153,13 @@ class PerDestinationQueue:
Args:
pdu: pdu to send
"""
- self._pending_pdus.append(pdu)
+ if not self._catching_up or self._last_successful_stream_ordering is None:
+ # only enqueue the PDU if we are not catching up (`_catching_up` is
+ # False) or do not yet know whether we have anything to catch up on
+ # (`_last_successful_stream_ordering` is None)
+ self._pending_pdus.append(pdu)
+ else:
+ self._catchup_last_skipped = pdu.internal_metadata.stream_ordering
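+ # (the skipped PDU is not lost: once the destination is reachable
+ # again, `_catch_up_transmission_loop` re-fetches this room's
+ # outstanding events from the event store)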
+
self.attempt_new_transaction()
def send_presence(self, states: Iterable[UserPresenceState]) -> None:
@@ -218,6 +239,13 @@ class PerDestinationQueue:
# hence why we throw the result away.
await get_retry_limiter(self._destination, self._clock, self._store)
+ if self._catching_up:
+ # we potentially need to catch up first
+ await self._catch_up_transmission_loop()
+ if self._catching_up:
+ # not caught up yet
+ return
+
pending_pdus = []
while True:
# We have to keep 2 free slots for presence and rr_edus
@@ -325,6 +353,17 @@ class PerDestinationQueue:
self._last_device_stream_id = device_stream_id
self._last_device_list_stream_id = dev_list_id
+
+ if pending_pdus:
+ # we sent some PDUs and it was successful, so update our
+ # last_successful_stream_ordering in the destinations table.
+ final_pdu = pending_pdus[-1]
+ last_successful_stream_ordering = (
+ final_pdu.internal_metadata.stream_ordering
+ )
+ await self._store.set_destination_last_successful_stream_ordering(
+ self._destination, last_successful_stream_ordering
+ )
else:
break
except NotRetryingDestination as e:
@@ -340,8 +379,9 @@ class PerDestinationQueue:
if e.retry_interval > 60 * 60 * 1000:
# we won't retry for another hour!
# (this suggests a significant outage)
- # We drop pending PDUs and EDUs because otherwise they will
+ # We drop pending EDUs because otherwise they will
# rack up indefinitely.
+ # (Dropping PDUs is already performed by `_start_catching_up`.)
# Note that:
# - the EDUs that are being dropped here are those that we can
# afford to drop (specifically, only typing notifications,
@@ -353,11 +393,12 @@ class PerDestinationQueue:
# dropping read receipts is a bit sad but should be solved
# through another mechanism, because this is all volatile!
- self._pending_pdus = []
self._pending_edus = []
self._pending_edus_keyed = {}
self._pending_presence = {}
self._pending_rrs = {}
+
+ self._start_catching_up()
except FederationDeniedError as e:
logger.info(e)
except HttpResponseException as e:
@@ -367,6 +408,8 @@ class PerDestinationQueue:
e.code,
e,
)
+
+ self._start_catching_up()
except RequestSendFailed as e:
logger.warning(
"TX [%s] Failed to send transaction: %s", self._destination, e
@@ -376,16 +419,96 @@ class PerDestinationQueue:
logger.info(
"Failed to send event %s to %s", p.event_id, self._destination
)
+
+ self._start_catching_up()
except Exception:
logger.exception("TX [%s] Failed to send transaction", self._destination)
for p in pending_pdus:
logger.info(
"Failed to send event %s to %s", p.event_id, self._destination
)
+
+ self._start_catching_up()
finally:
# We want to be *very* sure we clear this after we stop processing
self.transmission_loop_running = False
+ async def _catch_up_transmission_loop(self) -> None:
+ first_catch_up_check = self._last_successful_stream_ordering is None
+
+ if first_catch_up_check:
+ # first catchup so get last_successful_stream_ordering from database
+ self._last_successful_stream_ordering = await self._store.get_destination_last_successful_stream_ordering(
+ self._destination
+ )
+
+ if self._last_successful_stream_ordering is None:
+ # if it's still None, then we don't have the information in our
+ # database because we haven't successfully sent a PDU to this server
+ # (at least since the introduction of the feature tracking
+ # last_successful_stream_ordering).
+ # Sadly, this means we can't do anything here as we don't know what
+ # needs catching up — so catching up is futile; let's stop.
+ self._catching_up = False
+ return
+
+ # fetch at most 50 rooms' worth of catch-up PDUs at a time
+ while True:
+ event_ids = await self._store.get_catch_up_room_event_ids(
+ self._destination, self._last_successful_stream_ordering,
+ )
+
+ if not event_ids:
+ # No more events to catch up on, but we can't ignore the chance
+ # of a race condition, so we check that no new events have been
+ # skipped due to us being in catch-up mode
+
+ if self._catchup_last_skipped > self._last_successful_stream_ordering:
+ # another event has been skipped because we were in catch-up mode
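+ # (`send_pdu` may have recorded a skip between the query above
+ # returning empty and this check); loop round and query again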
+ continue
+
+ # we are done catching up!
+ self._catching_up = False
+ break
+
+ if first_catch_up_check:
+ # as this is our first catch-up check, we may have PDUs in
+ # the queue from before we *knew* we had to do catch-up, so
+ # clear those out now.
+ self._start_catching_up()
+
+ # fetch the relevant events from the event store
+ # - redacted behaviour of REDACT is fine, since we only send metadata
+ # of redacted events to the destination.
+ # - don't need to worry about rejected events as we do not actively
+ # forward received events over federation.
+ catchup_pdus = await self._store.get_events_as_list(event_ids)
+ if not catchup_pdus:
+ raise AssertionError(
+ "No events retrieved when we asked for %r. "
+ "This should not happen." % event_ids
+ )
+
+ if logger.isEnabledFor(logging.INFO):
+ rooms = [p.room_id for p in catchup_pdus]  # a list, so %r shows the IDs
+ logger.info("Catching up rooms to %s: %r", self._destination, rooms)
+
+ success = await self._transaction_manager.send_new_transaction(
+ self._destination, catchup_pdus, []
+ )
+
+ if not success:
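+ # transaction failed; leave `_catching_up` set so the next
+ # transmission attempt resumes catch-up from the same stream
+ # ordering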
+ return
+
+ sent_transactions_counter.inc()
+ final_pdu = catchup_pdus[-1]
+ self._last_successful_stream_ordering = cast(
+ int, final_pdu.internal_metadata.stream_ordering
+ )
+ await self._store.set_destination_last_successful_stream_ordering(
+ self._destination, self._last_successful_stream_ordering
+ )
+
def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
if not self._pending_rrs:
return
@@ -446,3 +569,12 @@ class PerDestinationQueue:
]
return (edus, stream_id)
+
+ def _start_catching_up(self) -> None:
+ """
+ Marks this destination as being in catch-up mode.
+
+ This throws away the PDU queue.
+ """
+ self._catching_up = True
+ self._pending_pdus = []
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 17a10f622e..4f7996f947 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -16,7 +16,7 @@
import logging
import urllib
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
from synapse.api.constants import Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
@@ -1004,6 +1004,20 @@ class TransportLayerClient:
return self.client.get_json(destination=destination, path=path)
+ def get_info_of_users(self, destination: str, user_ids: List[str]):
+ """
+ Args:
+ destination: The remote server
+ user_ids: A list of user IDs to query info about
+
+ Returns:
+ Deferred[List]: A dictionary of User ID to information about that user.
+ """
+ path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/users/info")
+ data = {"user_ids": user_ids}
+
+ return self.client.post_json(destination=destination, path=path, data=data)
+
def _create_path(federation_prefix, path, *args):
"""
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index cc7e9a973b..7b4baddbf8 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -31,6 +31,7 @@ from synapse.api.urls import (
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
+ assert_params_in_dict,
parse_boolean_from_args,
parse_integer_from_args,
parse_json_object_from_request,
@@ -68,7 +69,7 @@ class TransportLayerServer(JsonResource):
self.clock = hs.get_clock()
self.servlet_groups = servlet_groups
- super(TransportLayerServer, self).__init__(hs, canonical_json=False)
+ super().__init__(hs, canonical_json=False)
self.authenticator = Authenticator(hs)
self.ratelimiter = hs.get_federation_ratelimiter()
@@ -376,9 +377,7 @@ class FederationSendServlet(BaseFederationServlet):
RATELIMIT = False
def __init__(self, handler, server_name, **kwargs):
- super(FederationSendServlet, self).__init__(
- handler, server_name=server_name, **kwargs
- )
+ super().__init__(handler, server_name=server_name, **kwargs)
self.server_name = server_name
# This is when someone is trying to send us a bunch of data.
@@ -773,9 +772,7 @@ class PublicRoomList(BaseFederationServlet):
PATH = "/publicRooms"
def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access):
- super(PublicRoomList, self).__init__(
- handler, authenticator, ratelimiter, server_name
- )
+ super().__init__(handler, authenticator, ratelimiter, server_name)
self.allow_access = allow_access
async def on_GET(self, origin, content, query):
@@ -848,6 +845,57 @@ class PublicRoomList(BaseFederationServlet):
return 200, data
+class FederationUserInfoServlet(BaseFederationServlet):
+ """
+ Return information about a set of users.
+
+ This API returns expiration and deactivation information about a set of
+ users. Requested users not local to this homeserver will be ignored.
+
+ Example request:
+ POST /users/info
+
+ {
+ "user_ids": [
+ "@alice:example.com",
+ "@bob:example.com"
+ ]
+ }
+
+ Example response:
+ {
+ "@alice:example.com": {
+ "expired": false,
+ "deactivated": true
+ }
+ }
+ """
+
+ PATH = "/users/info"
+ PREFIX = FEDERATION_UNSTABLE_PREFIX
+
+ def __init__(self, handler, authenticator, ratelimiter, server_name):
+ super().__init__(handler, authenticator, ratelimiter, server_name)
+ self.handler = handler
+
+ async def on_POST(self, origin, content, query):
+ assert_params_in_dict(content, required=["user_ids"])
+
+ user_ids = content.get("user_ids", [])
+
+ if not isinstance(user_ids, list):
+ raise SynapseError(
+ 400,
+ "'user_ids' must be a list of user ID strings",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ data = await self.handler.store.get_info_for_users(user_ids)
+ return 200, data
+
+
class FederationVersionServlet(BaseFederationServlet):
PATH = "/version"
@@ -1409,6 +1457,7 @@ FEDERATION_SERVLET_CLASSES = (
On3pidBindServlet,
FederationVersionServlet,
RoomComplexityServlet,
+ FederationUserInfoServlet,
) # type: Tuple[Type[BaseFederationServlet], ...]
OPENID_SERVLET_CLASSES = (
|