summary refs log tree commit diff
path: root/synapse/util/ratelimitutils.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/ratelimitutils.py')
-rw-r--r--synapse/util/ratelimitutils.py34
1 files changed, 25 insertions, 9 deletions
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 2aceb1a47f..bd72947bfe 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -34,6 +34,7 @@ from prometheus_client.core import Counter
 from typing_extensions import ContextManager
 
 from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTime
 
 from synapse.api.errors import LimitExceededError
 from synapse.config.ratelimiting import FederationRatelimitSettings
@@ -146,12 +147,14 @@ class FederationRateLimiter:
 
     def __init__(
         self,
+        reactor: IReactorTime,
         clock: Clock,
         config: FederationRatelimitSettings,
         metrics_name: Optional[str] = None,
     ):
         """
         Args:
+            reactor
             clock
             config
             metrics_name: The name of the rate limiter so we can differentiate it
@@ -163,7 +166,7 @@ class FederationRateLimiter:
 
         def new_limiter() -> "_PerHostRatelimiter":
             return _PerHostRatelimiter(
-                clock=clock, config=config, metrics_name=metrics_name
+                reactor=reactor, clock=clock, config=config, metrics_name=metrics_name
             )
 
         self.ratelimiters: DefaultDict[
@@ -194,12 +197,14 @@ class FederationRateLimiter:
 class _PerHostRatelimiter:
     def __init__(
         self,
+        reactor: IReactorTime,
         clock: Clock,
         config: FederationRatelimitSettings,
         metrics_name: Optional[str] = None,
     ):
         """
         Args:
+            reactor
             clock
             config
             metrics_name: The name of the rate limiter so we can differentiate it
@@ -207,6 +212,7 @@ class _PerHostRatelimiter:
                 for this rate limiter.
                 from the rest in the metrics
         """
+        self.reactor = reactor
         self.clock = clock
         self.metrics_name = metrics_name
 
@@ -364,12 +370,22 @@ class _PerHostRatelimiter:
 
     def _on_exit(self, request_id: object) -> None:
         logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id))
-        self.current_processing.discard(request_id)
-        try:
-            # start processing the next item on the queue.
-            _, deferred = self.ready_request_queue.popitem(last=False)
 
-            with PreserveLoggingContext():
-                deferred.callback(None)
-        except KeyError:
-            pass
+        # When requests complete synchronously, we will recursively start the next
+        # request in the queue. To avoid stack exhaustion, we defer starting the next
+        # request until the next reactor tick.
+
+        def start_next_request() -> None:
+            # We only remove the completed request from the list when we're about to
+            # start the next one, otherwise we can allow extra requests through.
+            self.current_processing.discard(request_id)
+            try:
+                # start processing the next item on the queue.
+                _, deferred = self.ready_request_queue.popitem(last=False)
+
+                with PreserveLoggingContext():
+                    deferred.callback(None)
+            except KeyError:
+                pass
+
+        self.reactor.callLater(0.0, start_next_request)