diff --git a/changelog.d/8342.bugfix b/changelog.d/8342.bugfix
new file mode 100644
index 0000000000..786057facb
--- /dev/null
+++ b/changelog.d/8342.bugfix
@@ -0,0 +1 @@
+Fix ratelimiting of federation `/send` requests.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 218df884b0..ff00f0b302 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -97,10 +97,16 @@ class FederationServer(FederationBase):
self.state = hs.get_state_handler()
self.device_handler = hs.get_device_handler()
+ self._federation_ratelimiter = hs.get_federation_ratelimiter()
self._server_linearizer = Linearizer("fed_server")
self._transaction_linearizer = Linearizer("fed_txn_handler")
+ # We cache results for transactions with the same ID
+ self._transaction_resp_cache = ResponseCache(
+ hs, "fed_txn_handler", timeout_ms=30000
+ )
+
self.transaction_actions = TransactionActions(self.store)
self.registry = hs.get_federation_registry()
@@ -135,22 +141,44 @@ class FederationServer(FederationBase):
request_time = self._clock.time_msec()
transaction = Transaction(**transaction_data)
+ transaction_id = transaction.transaction_id # type: ignore
- if not transaction.transaction_id: # type: ignore
+ if not transaction_id:
raise Exception("Transaction missing transaction_id")
- logger.debug("[%s] Got transaction", transaction.transaction_id) # type: ignore
+ logger.debug("[%s] Got transaction", transaction_id)
- # use a linearizer to ensure that we don't process the same transaction
- # multiple times in parallel.
- with (
- await self._transaction_linearizer.queue(
- (origin, transaction.transaction_id) # type: ignore
- )
- ):
- result = await self._handle_incoming_transaction(
- origin, transaction, request_time
- )
+ # We wrap in a ResponseCache so that we de-duplicate retried
+ # transactions.
+ return await self._transaction_resp_cache.wrap(
+ (origin, transaction_id),
+ self._on_incoming_transaction_inner,
+ origin,
+ transaction,
+ request_time,
+ )
+
+ async def _on_incoming_transaction_inner(
+ self, origin: str, transaction: Transaction, request_time: int
+ ) -> Tuple[int, Dict[str, Any]]:
+ # Use a linearizer to ensure that transactions from a remote are
+ # processed in order.
+ with await self._transaction_linearizer.queue(origin):
+ # We rate limit here *after* we've queued up the incoming requests,
+ # so that we don't fill up the ratelimiter with blocked requests.
+ #
+ # This is important as the ratelimiter allows N concurrent requests
+ # at a time, and only starts ratelimiting if there are more requests
+ # than that being processed at a time. If we queued up requests in
+ # the linearizer/response cache *after* the ratelimiting then those
+ # queued up requests would count as part of the allowed limit of N
+ # concurrent requests.
+ with self._federation_ratelimiter.ratelimit(origin) as d:
+ await d
+
+ result = await self._handle_incoming_transaction(
+ origin, transaction, request_time
+ )
return result
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 9325e0f857..cc7e9a973b 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -45,7 +45,6 @@ from synapse.logging.opentracing import (
)
from synapse.server import HomeServer
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
-from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
@@ -72,9 +71,7 @@ class TransportLayerServer(JsonResource):
super(TransportLayerServer, self).__init__(hs, canonical_json=False)
self.authenticator = Authenticator(hs)
- self.ratelimiter = FederationRateLimiter(
- self.clock, config=hs.config.rc_federation
- )
+ self.ratelimiter = hs.get_federation_ratelimiter()
self.register_servlets()
@@ -272,6 +269,8 @@ class BaseFederationServlet:
PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version
+ RATELIMIT = True # Whether to rate limit requests or not
+
def __init__(self, handler, authenticator, ratelimiter, server_name):
self.handler = handler
self.authenticator = authenticator
@@ -335,7 +334,7 @@ class BaseFederationServlet:
)
with scope:
- if origin:
+ if origin and self.RATELIMIT:
with ratelimiter.ratelimit(origin) as d:
await d
if request._disconnected:
@@ -372,6 +371,10 @@ class BaseFederationServlet:
class FederationSendServlet(BaseFederationServlet):
PATH = "/send/(?P<transaction_id>[^/]*)/?"
+ # We ratelimit manually in the handler as we queue up the requests and we
+ # don't want to fill up the ratelimiter with blocked requests.
+ RATELIMIT = False
+
def __init__(self, handler, server_name, **kwargs):
super(FederationSendServlet, self).__init__(
handler, server_name=server_name, **kwargs
diff --git a/synapse/server.py b/synapse/server.py
index 9055b97ac3..5e3752c333 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -114,6 +114,7 @@ from synapse.streams.events import EventSources
from synapse.types import DomainSpecificString
from synapse.util import Clock
from synapse.util.distributor import Distributor
+from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
@@ -642,6 +643,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_replication_streams(self) -> Dict[str, Stream]:
return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()}
+ @cache_in_self
+ def get_federation_ratelimiter(self) -> FederationRateLimiter:
+ return FederationRateLimiter(self.clock, config=self.config.rc_federation)
+
async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|