Diffstat (limited to 'synapse/util/ratelimitutils.py')
-rw-r--r--  synapse/util/ratelimitutils.py  214
1 file changed, 197 insertions, 17 deletions
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 6394cc39ac..2aceb1a47f 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -15,8 +15,23 @@
 import collections
 import contextlib
 import logging
+import threading
 import typing
-from typing import Any, DefaultDict, Iterator, List, Set
+from typing import (
+    Any,
+    Callable,
+    DefaultDict,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+)
+
+from prometheus_client.core import Counter
+from typing_extensions import ContextManager
 
 from twisted.internet import defer
 
@@ -27,6 +42,8 @@ from synapse.logging.context import (
     make_deferred_yieldable,
     run_in_background,
 )
+from synapse.logging.opentracing import start_active_span
+from synapse.metrics import Histogram, LaterGauge
 from synapse.util import Clock
 
 if typing.TYPE_CHECKING:
@@ -35,15 +52,127 @@ if typing.TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+# Track how much the ratelimiter is affecting requests
+rate_limit_sleep_counter = Counter(
+    "synapse_rate_limit_sleep",
+    "Number of requests slept by the rate limiter",
+    ["rate_limiter_name"],
+)
+rate_limit_reject_counter = Counter(
+    "synapse_rate_limit_reject",
+    "Number of requests rejected by the rate limiter",
+    ["rate_limiter_name"],
+)
+queue_wait_timer = Histogram(
+    "synapse_rate_limit_queue_wait_time_seconds",
+    "Amount of time spent waiting for the rate limiter to let our request through.",
+    ["rate_limiter_name"],
+    buckets=(
+        0.005,
+        0.01,
+        0.025,
+        0.05,
+        0.1,
+        0.25,
+        0.5,
+        0.75,
+        1.0,
+        2.5,
+        5.0,
+        10.0,
+        20.0,
+        "+Inf",
+    ),
+)
+
+
+_rate_limiter_instances: Set["FederationRateLimiter"] = set()
+# Protects the _rate_limiter_instances set from concurrent access
+_rate_limiter_instances_lock = threading.Lock()
+
+
+def _get_counts_from_rate_limiter_instance(
+    count_func: Callable[["FederationRateLimiter"], int]
+) -> Mapping[Tuple[str, ...], int]:
+    """Returns a count of something (slept/rejected hosts) by (metrics_name)"""
+    # Cast to a list to prevent it changing while the Prometheus
+    # thread is collecting metrics
+    with _rate_limiter_instances_lock:
+        rate_limiter_instances = list(_rate_limiter_instances)
+
+    # Map from (metrics_name,) -> int, the number of something like slept hosts
+    # or rejected hosts. The key type is Tuple[str], but we leave the length
+    # unspecified for compatability with LaterGauge's annotations.
+    counts: Dict[Tuple[str, ...], int] = {}
+    for rate_limiter_instance in rate_limiter_instances:
+        # Only track metrics if they provided a `metrics_name` to
+        # differentiate this instance of the rate limiter.
+        if rate_limiter_instance.metrics_name:
+            key = (rate_limiter_instance.metrics_name,)
+            counts[key] = count_func(rate_limiter_instance)
+
+    return counts
+
+
+# We track the number of affected hosts per time-period so we can
+# differentiate one really noisy homeserver from a general
+# ratelimit tuning problem across the federation.
+LaterGauge(
+    "synapse_rate_limit_sleep_affected_hosts",
+    "Number of hosts that had requests put to sleep",
+    ["rate_limiter_name"],
+    lambda: _get_counts_from_rate_limiter_instance(
+        lambda rate_limiter_instance: sum(
+            ratelimiter.should_sleep()
+            for ratelimiter in rate_limiter_instance.ratelimiters.values()
+        )
+    ),
+)
+LaterGauge(
+    "synapse_rate_limit_reject_affected_hosts",
+    "Number of hosts that had requests rejected",
+    ["rate_limiter_name"],
+    lambda: _get_counts_from_rate_limiter_instance(
+        lambda rate_limiter_instance: sum(
+            ratelimiter.should_reject()
+            for ratelimiter in rate_limiter_instance.ratelimiters.values()
+        )
+    ),
+)
+
+
 class FederationRateLimiter:
-    def __init__(self, clock: Clock, config: FederationRatelimitSettings):
+    """Used to rate limit request per-host."""
+
+    def __init__(
+        self,
+        clock: Clock,
+        config: FederationRatelimitSettings,
+        metrics_name: Optional[str] = None,
+    ):
+        """
+        Args:
+            clock
+            config
+            metrics_name: The name of the rate limiter so we can differentiate it
+                from the rest in the metrics. If `None`, we don't track metrics
+                for this rate limiter.
+
+        """
+        self.metrics_name = metrics_name
+
         def new_limiter() -> "_PerHostRatelimiter":
-            return _PerHostRatelimiter(clock=clock, config=config)
+            return _PerHostRatelimiter(
+                clock=clock, config=config, metrics_name=metrics_name
+            )
 
         self.ratelimiters: DefaultDict[
             str, "_PerHostRatelimiter"
         ] = collections.defaultdict(new_limiter)
 
+        with _rate_limiter_instances_lock:
+            _rate_limiter_instances.add(self)
+
     def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
         """Used to ratelimit an incoming request from a given host
 
@@ -54,22 +183,32 @@ class FederationRateLimiter:
                 # Handle request ...
 
         Args:
-            host (str): Origin of incoming request.
+            host: Origin of incoming request.
 
         Returns:
             context manager which returns a deferred.
         """
-        return self.ratelimiters[host].ratelimit()
+        return self.ratelimiters[host].ratelimit(host)
 
 
 class _PerHostRatelimiter:
-    def __init__(self, clock: Clock, config: FederationRatelimitSettings):
+    def __init__(
+        self,
+        clock: Clock,
+        config: FederationRatelimitSettings,
+        metrics_name: Optional[str] = None,
+    ):
         """
         Args:
             clock
             config
+            metrics_name: The name of the rate limiter so we can differentiate it
+                from the rest in the metrics. If `None`, we don't track metrics
+                for this rate limiter.
+                from the rest in the metrics
         """
         self.clock = clock
+        self.metrics_name = metrics_name
 
         self.window_size = config.window_size
         self.sleep_limit = config.sleep_limit
@@ -94,19 +233,45 @@ class _PerHostRatelimiter:
         self.request_times: List[int] = []
 
     @contextlib.contextmanager
-    def ratelimit(self) -> "Iterator[defer.Deferred[None]]":
+    def ratelimit(self, host: str) -> "Iterator[defer.Deferred[None]]":
         # `contextlib.contextmanager` takes a generator and turns it into a
         # context manager. The generator should only yield once with a value
         # to be returned by manager.
         # Exceptions will be reraised at the yield.
 
+        self.host = host
+
         request_id = object()
-        ret = self._on_enter(request_id)
+        # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
+        # type-checking, but we'd need Twisted >= 21.2.
+        ret = defer.ensureDeferred(self._on_enter_with_tracing(request_id))
         try:
             yield ret
         finally:
             self._on_exit(request_id)
 
+    def should_reject(self) -> bool:
+        """
+        Whether to reject the request if we already have too many queued up
+        (either sleeping or in the ready queue).
+        """
+        queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
+        return queue_size > self.reject_limit
+
+    def should_sleep(self) -> bool:
+        """
+        Whether to sleep the request if we already have too many requests coming
+        through within the window.
+        """
+        return len(self.request_times) > self.sleep_limit
+
+    async def _on_enter_with_tracing(self, request_id: object) -> None:
+        maybe_metrics_cm: ContextManager = contextlib.nullcontext()
+        if self.metrics_name:
+            maybe_metrics_cm = queue_wait_timer.labels(self.metrics_name).time()
+        with start_active_span("ratelimit wait"), maybe_metrics_cm:
+            await self._on_enter(request_id)
+
     def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
         time_now = self.clock.time_msec()
 
@@ -117,8 +282,10 @@ class _PerHostRatelimiter:
 
         # reject the request if we already have too many queued up (either
        # sleeping or in the ready queue).
-        queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
-        if queue_size > self.reject_limit:
+        if self.should_reject():
+            logger.debug("Ratelimiter(%s): rejecting request", self.host)
+            if self.metrics_name:
+                rate_limit_reject_counter.labels(self.metrics_name).inc()
             raise LimitExceededError(
                 retry_after_ms=int(self.window_size / self.sleep_limit)
             )
@@ -130,7 +297,8 @@ class _PerHostRatelimiter:
                 queue_defer: defer.Deferred[None] = defer.Deferred()
                 self.ready_request_queue[request_id] = queue_defer
                 logger.info(
-                    "Ratelimiter: queueing request (queue now %i items)",
+                    "Ratelimiter(%s): queueing request (queue now %i items)",
+                    self.host,
                     len(self.ready_request_queue),
                 )
                 return queue_defer
@@ -139,19 +307,29 @@ class _PerHostRatelimiter:
                 return defer.succeed(None)
 
         logger.debug(
-            "Ratelimit [%s]: len(self.request_times)=%d",
+            "Ratelimit(%s) [%s]: len(self.request_times)=%d",
+            self.host,
             id(request_id),
             len(self.request_times),
         )
 
-        if len(self.request_times) > self.sleep_limit:
-            logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec)
+        if self.should_sleep():
+            logger.debug(
+                "Ratelimiter(%s) [%s]: sleeping request for %f sec",
+                self.host,
+                id(request_id),
+                self.sleep_sec,
+            )
+            if self.metrics_name:
+                rate_limit_sleep_counter.labels(self.metrics_name).inc()
             ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
 
             self.sleeping_requests.add(request_id)
 
             def on_wait_finished(_: Any) -> "defer.Deferred[None]":
-                logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
+                logger.debug(
+                    "Ratelimit(%s) [%s]: Finished sleeping", self.host, id(request_id)
+                )
                 self.sleeping_requests.discard(request_id)
                 queue_defer = queue_request()
                 return queue_defer
@@ -161,7 +339,9 @@ class _PerHostRatelimiter:
             ret_defer = queue_request()
 
         def on_start(r: object) -> object:
-            logger.debug("Ratelimit [%s]: Processing req", id(request_id))
+            logger.debug(
+                "Ratelimit(%s) [%s]: Processing req", self.host, id(request_id)
+            )
             self.current_processing.add(request_id)
             return r
 
@@ -183,7 +363,7 @@ class _PerHostRatelimiter:
         return make_deferred_yieldable(ret_defer)
 
     def _on_exit(self, request_id: object) -> None:
-        logger.debug("Ratelimit [%s]: Processed req", id(request_id))
+        logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id))
         self.current_processing.discard(request_id)
         try:
             # start processing the next item on the queue.