diff options
Diffstat (limited to 'synapse/util/ratelimitutils.py')
-rw-r--r-- | synapse/util/ratelimitutils.py | 34 |
1 files changed, 25 insertions, 9 deletions
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 2aceb1a47f..bd72947bfe 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -34,6 +34,7 @@ from prometheus_client.core import Counter from typing_extensions import ContextManager from twisted.internet import defer +from twisted.internet.interfaces import IReactorTime from synapse.api.errors import LimitExceededError from synapse.config.ratelimiting import FederationRatelimitSettings @@ -146,12 +147,14 @@ class FederationRateLimiter: def __init__( self, + reactor: IReactorTime, clock: Clock, config: FederationRatelimitSettings, metrics_name: Optional[str] = None, ): """ Args: + reactor clock config metrics_name: The name of the rate limiter so we can differentiate it @@ -163,7 +166,7 @@ class FederationRateLimiter: def new_limiter() -> "_PerHostRatelimiter": return _PerHostRatelimiter( - clock=clock, config=config, metrics_name=metrics_name + reactor=reactor, clock=clock, config=config, metrics_name=metrics_name ) self.ratelimiters: DefaultDict[ @@ -194,12 +197,14 @@ class FederationRateLimiter: class _PerHostRatelimiter: def __init__( self, + reactor: IReactorTime, clock: Clock, config: FederationRatelimitSettings, metrics_name: Optional[str] = None, ): """ Args: + reactor clock config metrics_name: The name of the rate limiter so we can differentiate it @@ -207,6 +212,7 @@ class _PerHostRatelimiter: for this rate limiter. from the rest in the metrics """ + self.reactor = reactor self.clock = clock self.metrics_name = metrics_name @@ -364,12 +370,22 @@ class _PerHostRatelimiter: def _on_exit(self, request_id: object) -> None: logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id)) - self.current_processing.discard(request_id) - try: - # start processing the next item on the queue. - _, deferred = self.ready_request_queue.popitem(last=False) - with PreserveLoggingContext(): - deferred.callback(None) - except KeyError: - pass + # When requests complete synchronously, we will recursively start the next + # request in the queue. To avoid stack exhaustion, we defer starting the next + # request until the next reactor tick. + + def start_next_request() -> None: + # We only remove the completed request from the list when we're about to + # start the next one, otherwise we can allow extra requests through. + self.current_processing.discard(request_id) + try: + # start processing the next item on the queue. + _, deferred = self.ready_request_queue.popitem(last=False) + + with PreserveLoggingContext(): + deferred.callback(None) + except KeyError: + pass + + self.reactor.callLater(0.0, start_next_request) |