diff options
Diffstat (limited to 'synapse/federation')
-rw-r--r-- | synapse/federation/federation_server.py | 52 | ||||
-rw-r--r-- | synapse/federation/transport/server.py | 13 |
2 files changed, 48 insertions, 17 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 218df884b0..ff00f0b302 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -97,10 +97,16 @@ class FederationServer(FederationBase): self.state = hs.get_state_handler() self.device_handler = hs.get_device_handler() + self._federation_ratelimiter = hs.get_federation_ratelimiter() self._server_linearizer = Linearizer("fed_server") self._transaction_linearizer = Linearizer("fed_txn_handler") + # We cache results for transaction with the same ID + self._transaction_resp_cache = ResponseCache( + hs, "fed_txn_handler", timeout_ms=30000 + ) + self.transaction_actions = TransactionActions(self.store) self.registry = hs.get_federation_registry() @@ -135,22 +141,44 @@ class FederationServer(FederationBase): request_time = self._clock.time_msec() transaction = Transaction(**transaction_data) + transaction_id = transaction.transaction_id # type: ignore - if not transaction.transaction_id: # type: ignore + if not transaction_id: raise Exception("Transaction missing transaction_id") - logger.debug("[%s] Got transaction", transaction.transaction_id) # type: ignore + logger.debug("[%s] Got transaction", transaction_id) - # use a linearizer to ensure that we don't process the same transaction - # multiple times in parallel. - with ( - await self._transaction_linearizer.queue( - (origin, transaction.transaction_id) # type: ignore - ) - ): - result = await self._handle_incoming_transaction( - origin, transaction, request_time - ) + # We wrap in a ResponseCache so that we de-duplicate retried + # transactions. + return await self._transaction_resp_cache.wrap( + (origin, transaction_id), + self._on_incoming_transaction_inner, + origin, + transaction, + request_time, + ) + + async def _on_incoming_transaction_inner( + self, origin: str, transaction: Transaction, request_time: int + ) -> Tuple[int, Dict[str, Any]]: + # Use a linearizer to ensure that transactions from a remote are + # processed in order. + with await self._transaction_linearizer.queue(origin): + # We rate limit here *after* we've queued up the incoming requests, + # so that we don't fill up the ratelimiter with blocked requests. + # + # This is important as the ratelimiter allows N concurrent requests + # at a time, and only starts ratelimiting if there are more requests + # than that being processed at a time. If we queued up requests in + # the linearizer/response cache *after* the ratelimiting then those + # queued up requests would count as part of the allowed limit of N + # concurrent requests. + with self._federation_ratelimiter.ratelimit(origin) as d: + await d + + result = await self._handle_incoming_transaction( + origin, transaction, request_time + ) return result diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 9325e0f857..cc7e9a973b 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -45,7 +45,6 @@ from synapse.logging.opentracing import ( ) from synapse.server import HomeServer from synapse.types import ThirdPartyInstanceID, get_domain_from_id -from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) @@ -72,9 +71,7 @@ class TransportLayerServer(JsonResource): super(TransportLayerServer, self).__init__(hs, canonical_json=False) self.authenticator = Authenticator(hs) - self.ratelimiter = FederationRateLimiter( - self.clock, config=hs.config.rc_federation - ) + self.ratelimiter = hs.get_federation_ratelimiter() self.register_servlets() @@ -272,6 +269,8 @@ class BaseFederationServlet: PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version + RATELIMIT = True # Whether to rate limit requests or not + def __init__(self, handler, authenticator, ratelimiter, server_name): self.handler = handler self.authenticator = authenticator @@ -335,7 +334,7 @@ class BaseFederationServlet: ) with scope: - if origin: + if origin and self.RATELIMIT: with ratelimiter.ratelimit(origin) as d: await d if request._disconnected: @@ -372,6 +371,10 @@ class BaseFederationServlet: class FederationSendServlet(BaseFederationServlet): PATH = "/send/(?P<transaction_id>[^/]*)/?" + # We ratelimit manually in the handler as we queue up the requests and we + # don't want to fill up the ratelimiter with blocked requests. + RATELIMIT = False + def __init__(self, handler, server_name, **kwargs): super(FederationSendServlet, self).__init__( handler, server_name=server_name, **kwargs |