summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-09-18 10:50:04 +0100
committerErik Johnston <erik@matrix.org>2020-09-18 10:50:04 +0100
commit5e42e61609d2eb35092e31bfc9a0ded9b67f3510 (patch)
tree8fa9ea6966cfdc76bd47ae9af0a0ebc8e18dd87e /synapse/federation
parentMove lint dependencies to extras_require (#8330) (diff)
parentFix ratelimiting for federation `/send` requests. (#8342) (diff)
downloadsynapse-5e42e61609d2eb35092e31bfc9a0ded9b67f3510.tar.xz
Merge remote-tracking branch 'origin/release-v1.20.0' into develop
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_server.py52
-rw-r--r--synapse/federation/transport/server.py13
2 files changed, 48 insertions, 17 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 218df884b0..ff00f0b302 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -97,10 +97,16 @@ class FederationServer(FederationBase):
         self.state = hs.get_state_handler()
 
         self.device_handler = hs.get_device_handler()
+        self._federation_ratelimiter = hs.get_federation_ratelimiter()
 
         self._server_linearizer = Linearizer("fed_server")
         self._transaction_linearizer = Linearizer("fed_txn_handler")
 
+        # We cache results for transaction with the same ID
+        self._transaction_resp_cache = ResponseCache(
+            hs, "fed_txn_handler", timeout_ms=30000
+        )
+
         self.transaction_actions = TransactionActions(self.store)
 
         self.registry = hs.get_federation_registry()
@@ -135,22 +141,44 @@ class FederationServer(FederationBase):
         request_time = self._clock.time_msec()
 
         transaction = Transaction(**transaction_data)
+        transaction_id = transaction.transaction_id  # type: ignore
 
-        if not transaction.transaction_id:  # type: ignore
+        if not transaction_id:
             raise Exception("Transaction missing transaction_id")
 
-        logger.debug("[%s] Got transaction", transaction.transaction_id)  # type: ignore
+        logger.debug("[%s] Got transaction", transaction_id)
 
-        # use a linearizer to ensure that we don't process the same transaction
-        # multiple times in parallel.
-        with (
-            await self._transaction_linearizer.queue(
-                (origin, transaction.transaction_id)  # type: ignore
-            )
-        ):
-            result = await self._handle_incoming_transaction(
-                origin, transaction, request_time
-            )
+        # We wrap in a ResponseCache so that we de-duplicate retried
+        # transactions.
+        return await self._transaction_resp_cache.wrap(
+            (origin, transaction_id),
+            self._on_incoming_transaction_inner,
+            origin,
+            transaction,
+            request_time,
+        )
+
+    async def _on_incoming_transaction_inner(
+        self, origin: str, transaction: Transaction, request_time: int
+    ) -> Tuple[int, Dict[str, Any]]:
+        # Use a linearizer to ensure that transactions from a remote are
+        # processed in order.
+        with await self._transaction_linearizer.queue(origin):
+            # We rate limit here *after* we've queued up the incoming requests,
+            # so that we don't fill up the ratelimiter with blocked requests.
+            #
+            # This is important as the ratelimiter allows N concurrent requests
+            # at a time, and only starts ratelimiting if there are more requests
+            # than that being processed at a time. If we queued up requests in
+            # the linearizer/response cache *after* the ratelimiting then those
+            # queued up requests would count as part of the allowed limit of N
+            # concurrent requests.
+            with self._federation_ratelimiter.ratelimit(origin) as d:
+                await d
+
+                result = await self._handle_incoming_transaction(
+                    origin, transaction, request_time
+                )
 
         return result
 
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 9325e0f857..cc7e9a973b 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -45,7 +45,6 @@ from synapse.logging.opentracing import (
 )
 from synapse.server import HomeServer
 from synapse.types import ThirdPartyInstanceID, get_domain_from_id
-from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger(__name__)
@@ -72,9 +71,7 @@ class TransportLayerServer(JsonResource):
         super(TransportLayerServer, self).__init__(hs, canonical_json=False)
 
         self.authenticator = Authenticator(hs)
-        self.ratelimiter = FederationRateLimiter(
-            self.clock, config=hs.config.rc_federation
-        )
+        self.ratelimiter = hs.get_federation_ratelimiter()
 
         self.register_servlets()
 
@@ -272,6 +269,8 @@ class BaseFederationServlet:
 
     PREFIX = FEDERATION_V1_PREFIX  # Allows specifying the API version
 
+    RATELIMIT = True  # Whether to rate limit requests or not
+
     def __init__(self, handler, authenticator, ratelimiter, server_name):
         self.handler = handler
         self.authenticator = authenticator
@@ -335,7 +334,7 @@ class BaseFederationServlet:
                 )
 
             with scope:
-                if origin:
+                if origin and self.RATELIMIT:
                     with ratelimiter.ratelimit(origin) as d:
                         await d
                         if request._disconnected:
@@ -372,6 +371,10 @@ class BaseFederationServlet:
 class FederationSendServlet(BaseFederationServlet):
     PATH = "/send/(?P<transaction_id>[^/]*)/?"
 
+    # We ratelimit manually in the handler as we queue up the requests and we
+    # don't want to fill up the ratelimiter with blocked requests.
+    RATELIMIT = False
+
     def __init__(self, handler, server_name, **kwargs):
         super(FederationSendServlet, self).__init__(
             handler, server_name=server_name, **kwargs