summary refs log tree commit diff
path: root/synapse/util/ratelimitutils.py
blob: dc9bddb00d5319d66f40ff0d76354fda992d4af4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#

import collections
import contextlib
import logging
import threading
import typing
from typing import (
    Any,
    Callable,
    ContextManager,
    DefaultDict,
    Dict,
    Iterator,
    List,
    Mapping,
    MutableSet,
    Optional,
    Set,
    Tuple,
)
from weakref import WeakSet

from prometheus_client.core import Counter

from twisted.internet import defer

from synapse.api.errors import LimitExceededError
from synapse.config.ratelimiting import FederationRatelimitSettings
from synapse.logging.context import (
    PreserveLoggingContext,
    make_deferred_yieldable,
    run_in_background,
)
from synapse.logging.opentracing import start_active_span
from synapse.metrics import Histogram, LaterGauge
from synapse.util import Clock

if typing.TYPE_CHECKING:
    from contextlib import _GeneratorContextManager

logger = logging.getLogger(__name__)


# Track how much the ratelimiter is affecting requests.
# Both counters are only incremented by limiters that were given a
# `metrics_name`; that name is used as the `rate_limiter_name` label.
rate_limit_sleep_counter = Counter(
    "synapse_rate_limit_sleep",
    "Number of requests slept by the rate limiter",
    ["rate_limiter_name"],
)
rate_limit_reject_counter = Counter(
    "synapse_rate_limit_reject",
    "Number of requests rejected by the rate limiter",
    ["rate_limiter_name"],
)
# Observes how long a request spent in `_PerHostRatelimiter._on_enter` (i.e.
# sleeping and/or waiting in the ready queue) before being let through. Only
# recorded for limiters that have a `metrics_name`.
queue_wait_timer = Histogram(
    "synapse_rate_limit_queue_wait_time_seconds",
    "Amount of time spent waiting for the rate limiter to let our request through.",
    ["rate_limiter_name"],
    buckets=(
        0.005,
        0.01,
        0.025,
        0.05,
        0.1,
        0.25,
        0.5,
        0.75,
        1.0,
        2.5,
        5.0,
        10.0,
        20.0,
        "+Inf",
    ),
)


# Every `FederationRateLimiter` registers itself here on construction so that
# the "affected hosts" `LaterGauge`s can report on all live limiters.
# This must be a `WeakSet`, otherwise we indirectly hold on to entire `HomeServer`s
# during trial test runs and leak a lot of memory.
_rate_limiter_instances: MutableSet["FederationRateLimiter"] = WeakSet()
# Protects the _rate_limiter_instances set from concurrent access: limiters add
# themselves to it while the Prometheus thread may be reading it.
_rate_limiter_instances_lock = threading.Lock()


def _get_counts_from_rate_limiter_instance(
    count_func: Callable[["FederationRateLimiter"], int]
) -> Mapping[Tuple[str, ...], int]:
    """Aggregate a per-limiter count (e.g. slept/rejected hosts) by `metrics_name`.

    Args:
        count_func: Called once for every live `FederationRateLimiter` that has
            a `metrics_name`; returns that limiter's count.

    Returns:
        A map from `(metrics_name,)` to the total of `count_func` over all live
        limiters sharing that name. Limiters without a `metrics_name` are
        skipped.
    """
    # Cast to a list to prevent it changing while the Prometheus
    # thread is collecting metrics
    with _rate_limiter_instances_lock:
        rate_limiter_instances = list(_rate_limiter_instances)

    # Map from (metrics_name,) -> int, the number of something like slept hosts
    # or rejected hosts. The key type is Tuple[str], but we leave the length
    # unspecified for compatibility with LaterGauge's annotations.
    counts: Dict[Tuple[str, ...], int] = {}
    for rate_limiter_instance in rate_limiter_instances:
        # Only track metrics if they provided a `metrics_name` to
        # differentiate this instance of the rate limiter.
        if not rate_limiter_instance.metrics_name:
            continue

        key = (rate_limiter_instance.metrics_name,)
        # Several live limiters can share a metrics name (e.g. more than one
        # HomeServer alive in the same process). Sum them rather than letting
        # whichever one the WeakSet happens to yield last overwrite the others.
        counts[key] = counts.get(key, 0) + count_func(rate_limiter_instance)

    return counts


# We track the number of affected hosts per time-period so we can
# differentiate one really noisy homeserver from a general
# ratelimit tuning problem across the federation.
#
# These callbacks run on the Prometheus collection thread, while the reactor
# may concurrently add hosts to `ratelimiters` (a `defaultdict`, so looking up
# an unseen host inserts a new entry). Snapshot the values with `list()` before
# iterating: a generator over the live `.values()` view runs Python code between
# items (`should_sleep()`/`should_reject()`), so a concurrent insert could raise
# "RuntimeError: dictionary changed size during iteration". In CPython the
# `list()` copy of a dict view is done in a single C-level call.
LaterGauge(
    "synapse_rate_limit_sleep_affected_hosts",
    "Number of hosts that had requests put to sleep",
    ["rate_limiter_name"],
    lambda: _get_counts_from_rate_limiter_instance(
        lambda rate_limiter_instance: sum(
            ratelimiter.should_sleep()
            for ratelimiter in list(rate_limiter_instance.ratelimiters.values())
        )
    ),
)
LaterGauge(
    "synapse_rate_limit_reject_affected_hosts",
    "Number of hosts that had requests rejected",
    ["rate_limiter_name"],
    lambda: _get_counts_from_rate_limiter_instance(
        lambda rate_limiter_instance: sum(
            ratelimiter.should_reject()
            for ratelimiter in list(rate_limiter_instance.ratelimiters.values())
        )
    ),
)


class FederationRateLimiter:
    """Rate limits incoming requests separately for each origin host.

    A `_PerHostRatelimiter` is created lazily the first time a host is seen;
    all of them share this limiter's clock, config and metrics name.
    """

    def __init__(
        self,
        clock: Clock,
        config: FederationRatelimitSettings,
        metrics_name: Optional[str] = None,
    ):
        """
        Args:
            clock: handed to every per-host limiter.
            config: the ratelimit settings applied to every host.
            metrics_name: The name of the rate limiter so we can differentiate it
                from the rest in the metrics. If `None`, we don't track metrics
                for this rate limiter.
        """
        self.metrics_name = metrics_name

        # host -> its limiter; an entry is created on first lookup of a host.
        self.ratelimiters: DefaultDict[str, "_PerHostRatelimiter"] = (
            collections.defaultdict(
                lambda: _PerHostRatelimiter(
                    clock=clock, config=config, metrics_name=metrics_name
                )
            )
        )

        # Make this limiter visible to the module-level "affected hosts" gauges.
        with _rate_limiter_instances_lock:
            _rate_limiter_instances.add(self)

    def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
        """Ratelimit an incoming request from `host`.

        Example usage:

            with rate_limiter.ratelimit(origin) as wait_deferred:
                yield wait_deferred
                # Handle request ...

        Args:
            host: Origin of incoming request.

        Returns:
            A context manager whose value is a deferred that resolves when the
            request may proceed, or fails (e.g. with `LimitExceededError`) if the
            request is rejected.
        """
        per_host_limiter = self.ratelimiters[host]
        return per_host_limiter.ratelimit(host)


class _PerHostRatelimiter:
    """Ratelimits requests coming from a single host.

    * When more than `sleep_limit` requests have arrived within the last
      `window_size` ms, a new request is delayed by `sleep_delay` ms.
    * At most `concurrent` requests are processed at once; the others wait in
      `ready_request_queue`.
    * When more than `reject_limit` requests are already sleeping or queued, a
      new request is rejected with `LimitExceededError`.
    """

    def __init__(
        self,
        clock: Clock,
        config: FederationRatelimitSettings,
        metrics_name: Optional[str] = None,
    ):
        """
        Args:
            clock
            config
            metrics_name: The name of the rate limiter so we can differentiate it
                from the rest in the metrics. If `None`, we don't track metrics
                for this rate limiter.
        """
        self.clock = clock
        self.metrics_name = metrics_name

        self.window_size = config.window_size
        self.sleep_limit = config.sleep_limit
        # `sleep_delay` is divided down to seconds for `clock.sleep`.
        self.sleep_sec = config.sleep_delay / 1000.0
        self.reject_limit = config.reject_limit
        self.concurrent_requests = config.concurrent

        # request_id objects for requests which have been slept
        self.sleeping_requests: Set[object] = set()

        # map from request_id object to Deferred for requests which are ready
        # for processing but have been queued
        self.ready_request_queue: collections.OrderedDict[
            object, defer.Deferred[None]
        ] = collections.OrderedDict()

        # request id objects for requests which are in progress
        self.current_processing: Set[object] = set()

        # times at which we have recently (within the last window_size ms)
        # received requests.
        self.request_times: List[int] = []

    @contextlib.contextmanager
    def ratelimit(self, host: str) -> "Iterator[defer.Deferred[None]]":
        """Context manager yielding a deferred that fires once the request may
        proceed. Leaving the context releases the request's processing slot (if
        it obtained one) so the next queued request can start.
        """
        # `contextlib.contextmanager` takes a generator and turns it into a
        # context manager. The generator should only yield once with a value
        # to be returned by manager.
        # Exceptions will be reraised at the yield.

        self.host = host

        request_id = object()
        # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
        # type-checking, but we'd need Twisted >= 21.2.
        ret = defer.ensureDeferred(self._on_enter_with_tracing(request_id))
        try:
            yield ret
        finally:
            self._on_exit(request_id)

    def should_reject(self) -> bool:
        """
        Whether to reject the request if we already have too many queued up
        (either sleeping or in the ready queue).
        """
        queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
        return queue_size > self.reject_limit

    def should_sleep(self) -> bool:
        """
        Whether to sleep the request if we already have too many requests coming
        through within the window.
        """
        return len(self.request_times) > self.sleep_limit

    async def _on_enter_with_tracing(self, request_id: object) -> None:
        """Run `_on_enter` inside a tracing span, timing the wait in
        `queue_wait_timer` when this limiter has a `metrics_name`."""
        maybe_metrics_cm: ContextManager = contextlib.nullcontext()
        if self.metrics_name:
            maybe_metrics_cm = queue_wait_timer.labels(self.metrics_name).time()
        with start_active_span("ratelimit wait"), maybe_metrics_cm:
            await self._on_enter(request_id)

    def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
        """Admit, delay, queue or reject a new request.

        Returns:
            A deferred that fires once the request may be processed; by then
            `request_id` has been added to `current_processing`.

        Raises:
            LimitExceededError: if too many requests are already sleeping or
                queued for this host.
        """
        time_now = self.clock.time_msec()

        # remove any entries from request_times which aren't within the window
        self.request_times[:] = [
            r for r in self.request_times if time_now - r < self.window_size
        ]

        # reject the request if we already have too many queued up (either
        # sleeping or in the ready queue).
        if self.should_reject():
            logger.debug("Ratelimiter(%s): rejecting request", self.host)
            if self.metrics_name:
                rate_limit_reject_counter.labels(self.metrics_name).inc()
            # Suggest retrying after roughly one request's share of the window.
            # Guard against a non-positive `sleep_limit`, which would otherwise
            # turn the rejection into a ZeroDivisionError.
            if self.sleep_limit > 0:
                retry_after_ms = int(self.window_size / self.sleep_limit)
            else:
                retry_after_ms = int(self.window_size)
            raise LimitExceededError(
                limiter_name="rc_federation",
                retry_after_ms=retry_after_ms,
            )

        self.request_times.append(time_now)

        def queue_request() -> "defer.Deferred[None]":
            # Start immediately if there is a free slot, otherwise wait in the
            # ready queue until `_on_exit` of an in-progress request hands one on.
            if len(self.current_processing) >= self.concurrent_requests:
                queue_defer: defer.Deferred[None] = defer.Deferred()
                self.ready_request_queue[request_id] = queue_defer
                logger.info(
                    "Ratelimiter(%s): queueing request (queue now %i items)",
                    self.host,
                    len(self.ready_request_queue),
                )

                return queue_defer
            else:
                return defer.succeed(None)

        logger.debug(
            "Ratelimit(%s) [%s]: len(self.request_times)=%d",
            self.host,
            id(request_id),
            len(self.request_times),
        )

        if self.should_sleep():
            logger.debug(
                "Ratelimiter(%s) [%s]: sleeping request for %f sec",
                self.host,
                id(request_id),
                self.sleep_sec,
            )
            if self.metrics_name:
                rate_limit_sleep_counter.labels(self.metrics_name).inc()
            ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)

            self.sleeping_requests.add(request_id)

            def on_wait_finished(_: Any) -> "defer.Deferred[None]":
                logger.debug(
                    "Ratelimit(%s) [%s]: Finished sleeping", self.host, id(request_id)
                )
                self.sleeping_requests.discard(request_id)
                queue_defer = queue_request()
                return queue_defer

            ret_defer.addBoth(on_wait_finished)
        else:
            ret_defer = queue_request()

        def on_start(r: object) -> object:
            # The request now holds a processing slot; `_on_exit` releases it.
            logger.debug(
                "Ratelimit(%s) [%s]: Processing req", self.host, id(request_id)
            )
            self.current_processing.add(request_id)
            return r

        def on_err(r: object) -> object:
            # XXX: why is this necessary? this is called before we start
            # processing the request so why would the request be in
            # current_processing?
            self.current_processing.discard(request_id)
            return r

        def on_both(r: object) -> object:
            # Ensure that we've properly cleaned up.
            self.sleeping_requests.discard(request_id)
            self.ready_request_queue.pop(request_id, None)
            return r

        ret_defer.addCallbacks(on_start, on_err)
        ret_defer.addBoth(on_both)
        return make_deferred_yieldable(ret_defer)

    def _on_exit(self, request_id: object) -> None:
        """Release `request_id`'s processing slot (if it holds one) and start the
        next queued request, on the next reactor tick."""
        logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id))

        # When requests complete synchronously, we will recursively start the next
        # request in the queue. To avoid stack exhaustion, we defer starting the next
        # request until the next reactor tick.

        def start_next_request() -> None:
            # Only a request that actually obtained a processing slot (added to
            # `current_processing` by `on_start` in `_on_enter`) has a slot to
            # hand on. Requests that were rejected, or whose wait failed or was
            # cancelled, never held one; letting them start the next queued
            # request would push the number of in-flight requests above
            # `concurrent_requests`.
            if request_id not in self.current_processing:
                return

            # We only remove the completed request from the list when we're about to
            # start the next one, otherwise we can allow extra requests through.
            self.current_processing.discard(request_id)

            try:
                # start processing the next item on the queue.
                _, deferred = self.ready_request_queue.popitem(last=False)
            except KeyError:
                # Nothing is waiting for a slot.
                return

            with PreserveLoggingContext():
                deferred.callback(None)

        self.clock.call_later(0.0, start_next_request)