summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_client.py6
-rw-r--r--synapse/federation/federation_server.py26
-rw-r--r--synapse/federation/sender/__init__.py62
-rw-r--r--synapse/federation/sender/per_destination_queue.py140
-rw-r--r--synapse/federation/sender/transaction_manager.py22
-rw-r--r--synapse/federation/transport/server.py10
6 files changed, 249 insertions, 17 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d42930d1b9..302b2f69bc 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -24,10 +24,12 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Sequence,
     Tuple,
     TypeVar,
+    Union,
 )
 
 from prometheus_client import Counter
@@ -79,7 +81,7 @@ class InvalidResponseError(RuntimeError):
 
 class FederationClient(FederationBase):
     def __init__(self, hs):
-        super(FederationClient, self).__init__(hs)
+        super().__init__(hs)
 
         self.pdu_destination_tried = {}
         self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
@@ -501,7 +503,7 @@ class FederationClient(FederationBase):
         user_id: str,
         membership: str,
         content: dict,
-        params: Dict[str, str],
+        params: Optional[Mapping[str, Union[str, Iterable[str]]]],
     ) -> Tuple[str, EventBase, RoomVersion]:
         """
         Creates an m.room.member event, with context, without participating in the room.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ff00f0b302..24329dd0e3 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -28,7 +28,7 @@ from typing import (
     Union,
 )
 
-from prometheus_client import Counter, Histogram
+from prometheus_client import Counter, Gauge, Histogram
 
 from twisted.internet import defer
 from twisted.internet.abstract import isIPAddress
@@ -88,9 +88,16 @@ pdu_process_time = Histogram(
 )
 
 
+last_pdu_age_metric = Gauge(
+    "synapse_federation_last_received_pdu_age",
+    "The age (in seconds) of the last PDU successfully received from the given domain",
+    labelnames=("server_name",),
+)
+
+
 class FederationServer(FederationBase):
     def __init__(self, hs):
-        super(FederationServer, self).__init__(hs)
+        super().__init__(hs)
 
         self.auth = hs.get_auth()
         self.handler = hs.get_handlers().federation_handler
@@ -118,6 +125,10 @@ class FederationServer(FederationBase):
             hs, "state_ids_resp", timeout_ms=30000
         )
 
+        self._federation_metrics_domains = (
+            hs.get_config().federation.federation_metrics_domains
+        )
+
     async def on_backfill_request(
         self, origin: str, room_id: str, versions: List[str], limit: int
     ) -> Tuple[int, Dict[str, Any]]:
@@ -262,7 +273,11 @@ class FederationServer(FederationBase):
 
         pdus_by_room = {}  # type: Dict[str, List[EventBase]]
 
+        newest_pdu_ts = 0
+
         for p in transaction.pdus:  # type: ignore
+            # FIXME (richardv): I don't think this works:
+            #  https://github.com/matrix-org/synapse/issues/8429
             if "unsigned" in p:
                 unsigned = p["unsigned"]
                 if "age" in unsigned:
@@ -300,6 +315,9 @@ class FederationServer(FederationBase):
             event = event_from_pdu_json(p, room_version)
             pdus_by_room.setdefault(room_id, []).append(event)
 
+            if event.origin_server_ts > newest_pdu_ts:
+                newest_pdu_ts = event.origin_server_ts
+
         pdu_results = {}
 
         # we can process different rooms in parallel (which is useful if they
@@ -340,6 +358,10 @@ class FederationServer(FederationBase):
             process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
         )
 
+        if newest_pdu_ts and origin in self._federation_metrics_domains:
+            newest_pdu_age = self._clock.time_msec() - newest_pdu_ts
+            last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000)
+
         return pdu_results
 
     async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 552519e82c..8bb17b3a05 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -55,6 +55,15 @@ sent_pdus_destination_dist_total = Counter(
     "Total number of PDUs queued for sending across all destinations",
 )
 
+# Time (in s) after Synapse's startup that we will begin to wake up destinations
+# that have catch-up outstanding.
+CATCH_UP_STARTUP_DELAY_SEC = 15
+
+# Time (in s) to wait in between waking up each destination, i.e. one destination
+# will be woken up every <x> seconds after Synapse's startup until we have woken
+# every destination has outstanding catch-up.
+CATCH_UP_STARTUP_INTERVAL_SEC = 5
+
 
 class FederationSender:
     def __init__(self, hs: "synapse.server.HomeServer"):
@@ -125,6 +134,14 @@ class FederationSender:
             1000.0 / hs.config.federation_rr_transactions_per_room_per_second
         )
 
+        # wake up destinations that have outstanding PDUs to be caught up
+        self._catchup_after_startup_timer = self.clock.call_later(
+            CATCH_UP_STARTUP_DELAY_SEC,
+            run_as_background_process,
+            "wake_destinations_needing_catchup",
+            self._wake_destinations_needing_catchup,
+        )
+
     def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
         """Get or create a PerDestinationQueue for the given destination
 
@@ -209,7 +226,7 @@ class FederationSender:
                     logger.debug("Sending %s to %r", event, destinations)
 
                     if destinations:
-                        self._send_pdu(event, destinations)
+                        await self._send_pdu(event, destinations)
 
                         now = self.clock.time_msec()
                         ts = await self.store.get_received_ts(event.event_id)
@@ -265,7 +282,7 @@ class FederationSender:
         finally:
             self._is_processing = False
 
-    def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
+    async def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
         # We loop through all destinations to see whether we already have
         # a transaction in progress. If we do, stick it in the pending_pdus
         # table and we'll get back to it later.
@@ -280,6 +297,13 @@ class FederationSender:
         sent_pdus_destination_dist_total.inc(len(destinations))
         sent_pdus_destination_dist_count.inc()
 
+        # track the fact that we have a PDU for these destinations,
+        # to allow us to perform catch-up later on if the remote is unreachable
+        # for a while.
+        await self.store.store_destination_rooms_entries(
+            destinations, pdu.room_id, pdu.internal_metadata.stream_ordering,
+        )
+
         for destination in destinations:
             self._get_per_destination_queue(destination).send_pdu(pdu)
 
@@ -553,3 +577,37 @@ class FederationSender:
         # Dummy implementation for case where federation sender isn't offloaded
         # to a worker.
         return [], 0, False
+
+    async def _wake_destinations_needing_catchup(self):
+        """
+        Wakes up destinations that need catch-up and are not currently being
+        backed off from.
+
+        In order to reduce load spikes, adds a delay between each destination.
+        """
+
+        last_processed = None  # type: Optional[str]
+
+        while True:
+            destinations_to_wake = await self.store.get_catch_up_outstanding_destinations(
+                last_processed
+            )
+
+            if not destinations_to_wake:
+                # finished waking all destinations!
+                self._catchup_after_startup_timer = None
+                break
+
+            destinations_to_wake = [
+                d
+                for d in destinations_to_wake
+                if self._federation_shard_config.should_handle(self._instance_name, d)
+            ]
+
+            for last_processed in destinations_to_wake:
+                logger.info(
+                    "Destination %s has outstanding catch-up, waking up.",
+                    last_processed,
+                )
+                self.wake_destination(last_processed)
+                await self.clock.sleep(CATCH_UP_STARTUP_INTERVAL_SEC)
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index defc228c23..bc99af3fdd 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 import datetime
 import logging
-from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast
 
 from prometheus_client import Counter
 
@@ -92,6 +92,21 @@ class PerDestinationQueue:
         self._destination = destination
         self.transmission_loop_running = False
 
+        # True whilst we are sending events that the remote homeserver missed
+        # because it was unreachable. We start in this state so we can perform
+        # catch-up at startup.
+        # New events will only be sent once this is finished, at which point
+        # _catching_up is flipped to False.
+        self._catching_up = True  # type: bool
+
+        # The stream_ordering of the most recent PDU that was discarded due to
+        # being in catch-up mode.
+        self._catchup_last_skipped = 0  # type: int
+
+        # Cache of the last successfully-transmitted stream ordering for this
+        # destination (we are the only updater so this is safe)
+        self._last_successful_stream_ordering = None  # type: Optional[int]
+
         # a list of pending PDUs
         self._pending_pdus = []  # type: List[EventBase]
 
@@ -138,7 +153,13 @@ class PerDestinationQueue:
         Args:
             pdu: pdu to send
         """
-        self._pending_pdus.append(pdu)
+        if not self._catching_up or self._last_successful_stream_ordering is None:
+            # only enqueue the PDU if we are not catching up (False) or do not
+            # yet know if we have anything to catch up (None)
+            self._pending_pdus.append(pdu)
+        else:
+            self._catchup_last_skipped = pdu.internal_metadata.stream_ordering
+
         self.attempt_new_transaction()
 
     def send_presence(self, states: Iterable[UserPresenceState]) -> None:
@@ -218,6 +239,13 @@ class PerDestinationQueue:
             # hence why we throw the result away.
             await get_retry_limiter(self._destination, self._clock, self._store)
 
+            if self._catching_up:
+                # we potentially need to catch-up first
+                await self._catch_up_transmission_loop()
+                if self._catching_up:
+                    # not caught up yet
+                    return
+
             pending_pdus = []
             while True:
                 # We have to keep 2 free slots for presence and rr_edus
@@ -325,6 +353,17 @@ class PerDestinationQueue:
 
                     self._last_device_stream_id = device_stream_id
                     self._last_device_list_stream_id = dev_list_id
+
+                    if pending_pdus:
+                        # we sent some PDUs and it was successful, so update our
+                        # last_successful_stream_ordering in the destinations table.
+                        final_pdu = pending_pdus[-1]
+                        last_successful_stream_ordering = (
+                            final_pdu.internal_metadata.stream_ordering
+                        )
+                        await self._store.set_destination_last_successful_stream_ordering(
+                            self._destination, last_successful_stream_ordering
+                        )
                 else:
                     break
         except NotRetryingDestination as e:
@@ -340,8 +379,9 @@ class PerDestinationQueue:
             if e.retry_interval > 60 * 60 * 1000:
                 # we won't retry for another hour!
                 # (this suggests a significant outage)
-                # We drop pending PDUs and EDUs because otherwise they will
+                # We drop pending EDUs because otherwise they will
                 # rack up indefinitely.
+                # (Dropping PDUs is already performed by `_start_catching_up`.)
                 # Note that:
                 # - the EDUs that are being dropped here are those that we can
                 #   afford to drop (specifically, only typing notifications,
@@ -353,11 +393,12 @@ class PerDestinationQueue:
 
                 # dropping read receipts is a bit sad but should be solved
                 # through another mechanism, because this is all volatile!
-                self._pending_pdus = []
                 self._pending_edus = []
                 self._pending_edus_keyed = {}
                 self._pending_presence = {}
                 self._pending_rrs = {}
+
+            self._start_catching_up()
         except FederationDeniedError as e:
             logger.info(e)
         except HttpResponseException as e:
@@ -367,6 +408,8 @@ class PerDestinationQueue:
                 e.code,
                 e,
             )
+
+            self._start_catching_up()
         except RequestSendFailed as e:
             logger.warning(
                 "TX [%s] Failed to send transaction: %s", self._destination, e
@@ -376,16 +419,96 @@ class PerDestinationQueue:
                 logger.info(
                     "Failed to send event %s to %s", p.event_id, self._destination
                 )
+
+            self._start_catching_up()
         except Exception:
             logger.exception("TX [%s] Failed to send transaction", self._destination)
             for p in pending_pdus:
                 logger.info(
                     "Failed to send event %s to %s", p.event_id, self._destination
                 )
+
+            self._start_catching_up()
         finally:
             # We want to be *very* sure we clear this after we stop processing
             self.transmission_loop_running = False
 
+    async def _catch_up_transmission_loop(self) -> None:
+        first_catch_up_check = self._last_successful_stream_ordering is None
+
+        if first_catch_up_check:
+            # first catchup so get last_successful_stream_ordering from database
+            self._last_successful_stream_ordering = await self._store.get_destination_last_successful_stream_ordering(
+                self._destination
+            )
+
+        if self._last_successful_stream_ordering is None:
+            # if it's still None, then this means we don't have the information
+            # in our database ­ we haven't successfully sent a PDU to this server
+            # (at least since the introduction of the feature tracking
+            # last_successful_stream_ordering).
+            # Sadly, this means we can't do anything here as we don't know what
+            # needs catching up — so catching up is futile; let's stop.
+            self._catching_up = False
+            return
+
+        # get at most 50 catchup room/PDUs
+        while True:
+            event_ids = await self._store.get_catch_up_room_event_ids(
+                self._destination, self._last_successful_stream_ordering,
+            )
+
+            if not event_ids:
+                # No more events to catch up on, but we can't ignore the chance
+                # of a race condition, so we check that no new events have been
+                # skipped due to us being in catch-up mode
+
+                if self._catchup_last_skipped > self._last_successful_stream_ordering:
+                    # another event has been skipped because we were in catch-up mode
+                    continue
+
+                # we are done catching up!
+                self._catching_up = False
+                break
+
+            if first_catch_up_check:
+                # as this is our check for needing catch-up, we may have PDUs in
+                # the queue from before we *knew* we had to do catch-up, so
+                # clear those out now.
+                self._start_catching_up()
+
+            # fetch the relevant events from the event store
+            # - redacted behaviour of REDACT is fine, since we only send metadata
+            #   of redacted events to the destination.
+            # - don't need to worry about rejected events as we do not actively
+            #   forward received events over federation.
+            catchup_pdus = await self._store.get_events_as_list(event_ids)
+            if not catchup_pdus:
+                raise AssertionError(
+                    "No events retrieved when we asked for %r. "
+                    "This should not happen." % event_ids
+                )
+
+            if logger.isEnabledFor(logging.INFO):
+                rooms = [p.room_id for p in catchup_pdus]
+                logger.info("Catching up rooms to %s: %r", self._destination, rooms)
+
+            success = await self._transaction_manager.send_new_transaction(
+                self._destination, catchup_pdus, []
+            )
+
+            if not success:
+                return
+
+            sent_transactions_counter.inc()
+            final_pdu = catchup_pdus[-1]
+            self._last_successful_stream_ordering = cast(
+                int, final_pdu.internal_metadata.stream_ordering
+            )
+            await self._store.set_destination_last_successful_stream_ordering(
+                self._destination, self._last_successful_stream_ordering
+            )
+
     def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
         if not self._pending_rrs:
             return
@@ -446,3 +569,12 @@ class PerDestinationQueue:
         ]
 
         return (edus, stream_id)
+
+    def _start_catching_up(self) -> None:
+        """
+        Marks this destination as being in catch-up mode.
+
+        This throws away the PDU queue.
+        """
+        self._catching_up = True
+        self._pending_pdus = []
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index c84072ab73..3e07f925e0 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -15,6 +15,8 @@
 import logging
 from typing import TYPE_CHECKING, List
 
+from prometheus_client import Gauge
+
 from synapse.api.errors import HttpResponseException
 from synapse.events import EventBase
 from synapse.federation.persistence import TransactionActions
@@ -34,6 +36,12 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+last_pdu_age_metric = Gauge(
+    "synapse_federation_last_sent_pdu_age",
+    "The age (in seconds) of the last PDU successfully sent to the given domain",
+    labelnames=("server_name",),
+)
+
 
 class TransactionManager:
     """Helper class which handles building and sending transactions
@@ -48,6 +56,10 @@ class TransactionManager:
         self._transaction_actions = TransactionActions(self._store)
         self._transport_layer = hs.get_federation_transport_client()
 
+        self._federation_metrics_domains = (
+            hs.get_config().federation.federation_metrics_domains
+        )
+
         # HACK to get unique tx id
         self._next_txn_id = int(self.clock.time_msec())
 
@@ -119,6 +131,9 @@ class TransactionManager:
 
             # FIXME (erikj): This is a bit of a hack to make the Pdu age
             # keys work
+            # FIXME (richardv): I also believe it no longer works. We (now?) store
+            #  "age_ts" in "unsigned" rather than at the top level. See
+            #  https://github.com/matrix-org/synapse/issues/8429.
             def json_data_cb():
                 data = transaction.get_dict()
                 now = int(self.clock.time_msec())
@@ -167,5 +182,12 @@ class TransactionManager:
                     )
                 success = False
 
+            if success and pdus and destination in self._federation_metrics_domains:
+                last_pdu = pdus[-1]
+                last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts
+                last_pdu_age_metric.labels(server_name=destination).set(
+                    last_pdu_age / 1000
+                )
+
             set_tag(tags.ERROR, not success)
             return success
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index cc7e9a973b..3a6b95631e 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -68,7 +68,7 @@ class TransportLayerServer(JsonResource):
         self.clock = hs.get_clock()
         self.servlet_groups = servlet_groups
 
-        super(TransportLayerServer, self).__init__(hs, canonical_json=False)
+        super().__init__(hs, canonical_json=False)
 
         self.authenticator = Authenticator(hs)
         self.ratelimiter = hs.get_federation_ratelimiter()
@@ -376,9 +376,7 @@ class FederationSendServlet(BaseFederationServlet):
     RATELIMIT = False
 
     def __init__(self, handler, server_name, **kwargs):
-        super(FederationSendServlet, self).__init__(
-            handler, server_name=server_name, **kwargs
-        )
+        super().__init__(handler, server_name=server_name, **kwargs)
         self.server_name = server_name
 
     # This is when someone is trying to send us a bunch of data.
@@ -773,9 +771,7 @@ class PublicRoomList(BaseFederationServlet):
     PATH = "/publicRooms"
 
     def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access):
-        super(PublicRoomList, self).__init__(
-            handler, authenticator, ratelimiter, server_name
-        )
+        super().__init__(handler, authenticator, ratelimiter, server_name)
         self.allow_access = allow_access
 
     async def on_GET(self, origin, content, query):