summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2020-10-08 17:05:01 +0100
committerAndrew Morgan <andrew@amorgan.xyz>2020-10-08 17:05:01 +0100
commit23b50d6fb881db23324cac9f64cba33a1d3747b3 (patch)
tree160ac2f45363782b0bbc77dc845f410b18585199 /synapse/federation
parentAdd `xyz.amorgan.knock` /versions string (diff)
parentMerge tag 'v1.21.0rc3' into develop (diff)
downloadsynapse-23b50d6fb881db23324cac9f64cba33a1d3747b3.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into soru/knock
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_client.py14
-rw-r--r--synapse/federation/federation_server.py83
-rw-r--r--synapse/federation/sender/__init__.py53
-rw-r--r--synapse/federation/sender/per_destination_queue.py4
-rw-r--r--synapse/federation/sender/transaction_manager.py22
-rw-r--r--synapse/federation/transport/server.py23
6 files changed, 162 insertions, 37 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 063605eaff..c8936a28ea 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -26,10 +26,12 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Sequence,
     Tuple,
     TypeVar,
+    Union,
 )
 
 from prometheus_client import Counter
@@ -81,7 +83,7 @@ class InvalidResponseError(RuntimeError):
 
 class FederationClient(FederationBase):
     def __init__(self, hs):
-        super(FederationClient, self).__init__(hs)
+        super().__init__(hs)
 
         self.pdu_destination_tried = {}
         self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
@@ -219,11 +221,9 @@ class FederationClient(FederationBase):
             for p in transaction_data["pdus"]
         ]
 
-        # FIXME: We should handle signature failures more gracefully.
-        pdus[:] = await make_deferred_yieldable(
-            defer.gatherResults(
-                self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True,
-            ).addErrback(unwrapFirstError)
+        # Check signatures and hash of pdus, removing any from the list that fail checks
+        pdus[:] = await self._check_sigs_and_hash_and_fetch(
+            dest, pdus, outlier=True, room_version=room_version
         )
 
         return pdus
@@ -505,7 +505,7 @@ class FederationClient(FederationBase):
         user_id: str,
         membership: str,
         content: dict,
-        params: Dict[str, str],
+        params: Optional[Mapping[str, Union[str, Iterable[str]]]],
     ) -> Tuple[str, EventBase, RoomVersion]:
         """
         Creates an m.room.member event, with context, without participating in the room.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 662325bab1..6035d2f664 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -22,13 +22,12 @@ from typing import (
     Callable,
     Dict,
     List,
-    Match,
     Optional,
     Tuple,
     Union,
 )
 
-from prometheus_client import Counter, Histogram
+from prometheus_client import Counter, Gauge, Histogram
 
 from twisted.internet import defer
 from twisted.internet.abstract import isIPAddress
@@ -88,19 +87,32 @@ pdu_process_time = Histogram(
 )
 
 
+last_pdu_age_metric = Gauge(
+    "synapse_federation_last_received_pdu_age",
+    "The age (in seconds) of the last PDU successfully received from the given domain",
+    labelnames=("server_name",),
+)
+
+
 class FederationServer(FederationBase):
     def __init__(self, hs):
-        super(FederationServer, self).__init__(hs)
+        super().__init__(hs)
 
         self.auth = hs.get_auth()
         self.handler = hs.get_handlers().federation_handler
         self.state = hs.get_state_handler()
 
         self.device_handler = hs.get_device_handler()
+        self._federation_ratelimiter = hs.get_federation_ratelimiter()
 
         self._server_linearizer = Linearizer("fed_server")
         self._transaction_linearizer = Linearizer("fed_txn_handler")
 
+        # We cache results for transaction with the same ID
+        self._transaction_resp_cache = ResponseCache(
+            hs, "fed_txn_handler", timeout_ms=30000
+        )
+
         self.transaction_actions = TransactionActions(self.store)
 
         self.registry = hs.get_federation_registry()
@@ -112,6 +124,10 @@ class FederationServer(FederationBase):
             hs, "state_ids_resp", timeout_ms=30000
         )
 
+        self._federation_metrics_domains = (
+            hs.get_config().federation.federation_metrics_domains
+        )
+
     async def on_backfill_request(
         self, origin: str, room_id: str, versions: List[str], limit: int
     ) -> Tuple[int, Dict[str, Any]]:
@@ -135,22 +151,44 @@ class FederationServer(FederationBase):
         request_time = self._clock.time_msec()
 
         transaction = Transaction(**transaction_data)
+        transaction_id = transaction.transaction_id  # type: ignore
 
-        if not transaction.transaction_id:  # type: ignore
+        if not transaction_id:
             raise Exception("Transaction missing transaction_id")
 
-        logger.debug("[%s] Got transaction", transaction.transaction_id)  # type: ignore
+        logger.debug("[%s] Got transaction", transaction_id)
 
-        # use a linearizer to ensure that we don't process the same transaction
-        # multiple times in parallel.
-        with (
-            await self._transaction_linearizer.queue(
-                (origin, transaction.transaction_id)  # type: ignore
-            )
-        ):
-            result = await self._handle_incoming_transaction(
-                origin, transaction, request_time
-            )
+        # We wrap in a ResponseCache so that we de-duplicate retried
+        # transactions.
+        return await self._transaction_resp_cache.wrap(
+            (origin, transaction_id),
+            self._on_incoming_transaction_inner,
+            origin,
+            transaction,
+            request_time,
+        )
+
+    async def _on_incoming_transaction_inner(
+        self, origin: str, transaction: Transaction, request_time: int
+    ) -> Tuple[int, Dict[str, Any]]:
+        # Use a linearizer to ensure that transactions from a remote are
+        # processed in order.
+        with await self._transaction_linearizer.queue(origin):
+            # We rate limit here *after* we've queued up the incoming requests,
+            # so that we don't fill up the ratelimiter with blocked requests.
+            #
+            # This is important as the ratelimiter allows N concurrent requests
+            # at a time, and only starts ratelimiting if there are more requests
+            # than that being processed at a time. If we queued up requests in
+            # the linearizer/response cache *after* the ratelimiting then those
+            # queued up requests would count as part of the allowed limit of N
+            # concurrent requests.
+            with self._federation_ratelimiter.ratelimit(origin) as d:
+                await d
+
+                result = await self._handle_incoming_transaction(
+                    origin, transaction, request_time
+                )
 
         return result
 
@@ -234,7 +272,11 @@ class FederationServer(FederationBase):
 
         pdus_by_room = {}  # type: Dict[str, List[EventBase]]
 
+        newest_pdu_ts = 0
+
         for p in transaction.pdus:  # type: ignore
+            # FIXME (richardv): I don't think this works:
+            #  https://github.com/matrix-org/synapse/issues/8429
             if "unsigned" in p:
                 unsigned = p["unsigned"]
                 if "age" in unsigned:
@@ -272,6 +314,9 @@ class FederationServer(FederationBase):
             event = event_from_pdu_json(p, room_version)
             pdus_by_room.setdefault(room_id, []).append(event)
 
+            if event.origin_server_ts > newest_pdu_ts:
+                newest_pdu_ts = event.origin_server_ts
+
         pdu_results = {}
 
         # we can process different rooms in parallel (which is useful if they
@@ -312,6 +357,10 @@ class FederationServer(FederationBase):
             process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
         )
 
+        if newest_pdu_ts and origin in self._federation_metrics_domains:
+            newest_pdu_age = self._clock.time_msec() - newest_pdu_ts
+            last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000)
+
         return pdu_results
 
     async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
@@ -801,14 +850,14 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
     return False
 
 
-def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
+def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool:
     if not isinstance(acl_entry, str):
         logger.warning(
             "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
         )
         return False
     regex = glob_to_regex(acl_entry)
-    return regex.match(server_name)
+    return bool(regex.match(server_name))
 
 
 class FederationHandlerRegistry:
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 41a726878d..e33b29a42c 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -55,6 +55,15 @@ sent_pdus_destination_dist_total = Counter(
     "Total number of PDUs queued for sending across all destinations",
 )
 
+# Time (in s) after Synapse's startup that we will begin to wake up destinations
+# that have catch-up outstanding.
+CATCH_UP_STARTUP_DELAY_SEC = 15
+
+# Time (in s) to wait in between waking up each destination, i.e. one destination
+# will be woken up every <x> seconds after Synapse's startup until we have woken
+# every destination has outstanding catch-up.
+CATCH_UP_STARTUP_INTERVAL_SEC = 5
+
 
 class FederationSender:
     def __init__(self, hs: "synapse.server.HomeServer"):
@@ -125,6 +134,14 @@ class FederationSender:
             1000.0 / hs.config.federation_rr_transactions_per_room_per_second
         )
 
+        # wake up destinations that have outstanding PDUs to be caught up
+        self._catchup_after_startup_timer = self.clock.call_later(
+            CATCH_UP_STARTUP_DELAY_SEC,
+            run_as_background_process,
+            "wake_destinations_needing_catchup",
+            self._wake_destinations_needing_catchup,
+        )
+
     def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
         """Get or create a PerDestinationQueue for the given destination
 
@@ -280,6 +297,8 @@ class FederationSender:
         sent_pdus_destination_dist_total.inc(len(destinations))
         sent_pdus_destination_dist_count.inc()
 
+        assert pdu.internal_metadata.stream_ordering
+
         # track the fact that we have a PDU for these destinations,
         # to allow us to perform catch-up later on if the remote is unreachable
         # for a while.
@@ -560,3 +579,37 @@ class FederationSender:
         # Dummy implementation for case where federation sender isn't offloaded
         # to a worker.
         return [], 0, False
+
+    async def _wake_destinations_needing_catchup(self):
+        """
+        Wakes up destinations that need catch-up and are not currently being
+        backed off from.
+
+        In order to reduce load spikes, adds a delay between each destination.
+        """
+
+        last_processed = None  # type: Optional[str]
+
+        while True:
+            destinations_to_wake = await self.store.get_catch_up_outstanding_destinations(
+                last_processed
+            )
+
+            if not destinations_to_wake:
+                # finished waking all destinations!
+                self._catchup_after_startup_timer = None
+                break
+
+            destinations_to_wake = [
+                d
+                for d in destinations_to_wake
+                if self._federation_shard_config.should_handle(self._instance_name, d)
+            ]
+
+            for last_processed in destinations_to_wake:
+                logger.info(
+                    "Destination %s has outstanding catch-up, waking up.",
+                    last_processed,
+                )
+                self.wake_destination(last_processed)
+                await self.clock.sleep(CATCH_UP_STARTUP_INTERVAL_SEC)
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 2657767fd1..db8e456fe8 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -158,6 +158,7 @@ class PerDestinationQueue:
             # yet know if we have anything to catch up (None)
             self._pending_pdus.append(pdu)
         else:
+            assert pdu.internal_metadata.stream_ordering
             self._catchup_last_skipped = pdu.internal_metadata.stream_ordering
 
         self.attempt_new_transaction()
@@ -361,6 +362,7 @@ class PerDestinationQueue:
                         last_successful_stream_ordering = (
                             final_pdu.internal_metadata.stream_ordering
                         )
+                        assert last_successful_stream_ordering
                         await self._store.set_destination_last_successful_stream_ordering(
                             self._destination, last_successful_stream_ordering
                         )
@@ -490,7 +492,7 @@ class PerDestinationQueue:
                 )
 
             if logger.isEnabledFor(logging.INFO):
-                rooms = (p.room_id for p in catchup_pdus)
+                rooms = [p.room_id for p in catchup_pdus]
                 logger.info("Catching up rooms to %s: %r", self._destination, rooms)
 
             success = await self._transaction_manager.send_new_transaction(
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index c84072ab73..3e07f925e0 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -15,6 +15,8 @@
 import logging
 from typing import TYPE_CHECKING, List
 
+from prometheus_client import Gauge
+
 from synapse.api.errors import HttpResponseException
 from synapse.events import EventBase
 from synapse.federation.persistence import TransactionActions
@@ -34,6 +36,12 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+last_pdu_age_metric = Gauge(
+    "synapse_federation_last_sent_pdu_age",
+    "The age (in seconds) of the last PDU successfully sent to the given domain",
+    labelnames=("server_name",),
+)
+
 
 class TransactionManager:
     """Helper class which handles building and sending transactions
@@ -48,6 +56,10 @@ class TransactionManager:
         self._transaction_actions = TransactionActions(self._store)
         self._transport_layer = hs.get_federation_transport_client()
 
+        self._federation_metrics_domains = (
+            hs.get_config().federation.federation_metrics_domains
+        )
+
         # HACK to get unique tx id
         self._next_txn_id = int(self.clock.time_msec())
 
@@ -119,6 +131,9 @@ class TransactionManager:
 
             # FIXME (erikj): This is a bit of a hack to make the Pdu age
             # keys work
+            # FIXME (richardv): I also believe it no longer works. We (now?) store
+            #  "age_ts" in "unsigned" rather than at the top level. See
+            #  https://github.com/matrix-org/synapse/issues/8429.
             def json_data_cb():
                 data = transaction.get_dict()
                 now = int(self.clock.time_msec())
@@ -167,5 +182,12 @@ class TransactionManager:
                     )
                 success = False
 
+            if success and pdus and destination in self._federation_metrics_domains:
+                last_pdu = pdus[-1]
+                last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts
+                last_pdu_age_metric.labels(server_name=destination).set(
+                    last_pdu_age / 1000
+                )
+
             set_tag(tags.ERROR, not success)
             return success
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index e04704d10c..a2fb558b45 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -46,7 +46,6 @@ from synapse.logging.opentracing import (
 )
 from synapse.server import HomeServer
 from synapse.types import ThirdPartyInstanceID, get_domain_from_id
-from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger(__name__)
@@ -70,12 +69,10 @@ class TransportLayerServer(JsonResource):
         self.clock = hs.get_clock()
         self.servlet_groups = servlet_groups
 
-        super(TransportLayerServer, self).__init__(hs, canonical_json=False)
+        super().__init__(hs, canonical_json=False)
 
         self.authenticator = Authenticator(hs)
-        self.ratelimiter = FederationRateLimiter(
-            self.clock, config=hs.config.rc_federation
-        )
+        self.ratelimiter = hs.get_federation_ratelimiter()
 
         self.register_servlets()
 
@@ -273,6 +270,8 @@ class BaseFederationServlet:
 
     PREFIX = FEDERATION_V1_PREFIX  # Allows specifying the API version
 
+    RATELIMIT = True  # Whether to rate limit requests or not
+
     def __init__(self, handler, authenticator, ratelimiter, server_name):
         self.handler = handler
         self.authenticator = authenticator
@@ -336,7 +335,7 @@ class BaseFederationServlet:
                 )
 
             with scope:
-                if origin:
+                if origin and self.RATELIMIT:
                     with ratelimiter.ratelimit(origin) as d:
                         await d
                         if request._disconnected:
@@ -373,10 +372,12 @@ class BaseFederationServlet:
 class FederationSendServlet(BaseFederationServlet):
     PATH = "/send/(?P<transaction_id>[^/]*)/?"
 
+    # We ratelimit manually in the handler as we queue up the requests and we
+    # don't want to fill up the ratelimiter with blocked requests.
+    RATELIMIT = False
+
     def __init__(self, handler, server_name, **kwargs):
-        super(FederationSendServlet, self).__init__(
-            handler, server_name=server_name, **kwargs
-        )
+        super().__init__(handler, server_name=server_name, **kwargs)
         self.server_name = server_name
 
     # This is when someone is trying to send us a bunch of data.
@@ -787,9 +788,7 @@ class PublicRoomList(BaseFederationServlet):
     PATH = "/publicRooms"
 
     def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access):
-        super(PublicRoomList, self).__init__(
-            handler, authenticator, ratelimiter, server_name
-        )
+        super().__init__(handler, authenticator, ratelimiter, server_name)
         self.allow_access = allow_access
 
     async def on_GET(self, origin, content, query):