summary refs log tree commit diff
path: root/synapse/api/ratelimiting.py
blob: 79b7631172bca3923210601febadb9af62cbb920 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from typing import Any, Dict, Optional, Tuple

from synapse.api.errors import LimitExceededError


class Ratelimiter(object):
    """
    Ratelimit actions marked by arbitrary keys.

    Each key has a counter of performed actions which drains at `rate_hz` per
    second. An action is denied while the drained count would exceed
    `burst_count`.

    Args:
        rate_hz: The long term number of actions that can be performed in a second.
        burst_count: How many actions that can be performed before being limited.
    """

    def __init__(self, rate_hz: float, burst_count: int):
        # An ordered dictionary keeping track of actions, when they were last
        # performed and how often. Each entry is a mapping from a key of arbitrary type
        # to a tuple representing:
        #   * How many times an action has occurred since a point in time
        #   * That point in time
        self.actions = OrderedDict()  # type: OrderedDict[Any, Tuple[float, int]]

        # The rate (in Hz) that was in effect when each key in `self.actions` was
        # last updated. Pruning must drain each entry at the rate it was limited
        # with, not at the (possibly overridden) rate of whichever call happens to
        # trigger the prune. Kept separate from `self.actions` so that the shape of
        # its entries is unchanged.
        self._action_rates_hz = {}  # type: Dict[Any, float]

        self.rate_hz = rate_hz
        self.burst_count = burst_count

    def can_do_action(
        self,
        key: Any,
        time_now_s: int,
        update: bool = True,
        rate_hz: Optional[float] = None,
        burst_count: Optional[int] = None,
    ) -> Tuple[bool, float]:
        """Can the entity (e.g. user or IP address) perform the action?

        Args:
            key: The key we should use when rate limiting. Can be a user ID
                (when sending events), an IP address, etc.
            time_now_s: The time now
            update: Whether to count this check as performing the action
            rate_hz: The long term number of actions that can be performed in a second.
                Overrides the value set during instantiation if set. A falsy value
                (None or 0) falls back to the instantiation value.
            burst_count: How many actions that can be performed before being limited.
                Overrides the value set during instantiation if set. A falsy value
                (None or 0) falls back to the instantiation value.

        Returns:
            A tuple containing:
                * A bool indicating if they can perform the action now
                * The time in seconds of when it can next be performed.
                  -1 if the effective rate_hz for this check is not positive
        """
        # Override default values if set
        rate_hz = rate_hz or self.rate_hz
        burst_count = burst_count or self.burst_count

        # Remove any expired entries, each at the rate it was recorded with
        self._prune_message_counts(time_now_s)

        # Check if there is an existing count entry for this key
        action_count, time_start = self.actions.get(key, (0.0, time_now_s))

        # Check whether performing another action is allowed. The stored count
        # drains at `rate_hz` actions per second since `time_start`.
        time_delta = time_now_s - time_start
        performed_count = action_count - time_delta * rate_hz
        if performed_count < 0:
            # Allow, reset back to count 1
            allowed = True
            time_start = time_now_s
            action_count = 1.0
        elif performed_count > burst_count - 1.0:
            # Deny, we have exceeded our burst count
            allowed = False
        else:
            # We haven't reached our limit yet
            allowed = True
            action_count += 1.0

        if update:
            self.actions[key] = (action_count, time_start)
            # Remember the rate this entry was computed with so pruning drains it
            # consistently, even when later calls use a different override.
            self._action_rates_hz[key] = rate_hz

        # Figure out the time when an action can be performed again. This must be
        # gated on the effective (possibly overridden) rate, since that is the rate
        # used both for the count above and for the division below.
        if rate_hz > 0:
            time_allowed = time_start + (action_count - burst_count + 1) / rate_hz

            # Don't give back a time in the past
            if time_allowed < time_now_s:
                time_allowed = time_now_s
        else:
            # This does not apply
            time_allowed = -1

        return allowed, time_allowed

    def _prune_message_counts(
        self, time_now_s: int, rate_hz: Optional[float] = None
    ):
        """Remove message count entries whose count has fully drained.

        Each entry is drained at the rate recorded for its key when it was last
        updated, rather than at one rate shared by all keys.

        Args:
            time_now_s: The current time
            rate_hz: Fallback rate for entries with no recorded rate (e.g. ones
                written to `self.actions` directly). Defaults to the rate set during
                instantiation.
        """
        default_rate_hz = self.rate_hz if rate_hz is None else rate_hz

        # We create a copy of the key list here as the dictionary is modified during
        # the loop
        for key in list(self.actions):
            action_count, time_start = self.actions[key]
            key_rate_hz = self._action_rates_hz.get(key, default_rate_hz)

            # Remaining count = count - "seconds since we started limiting" * rate.
            # Once nothing remains, the entry carries no information: wipe it.
            time_delta = time_now_s - time_start
            if action_count - time_delta * key_rate_hz <= 0:
                del self.actions[key]
                self._action_rates_hz.pop(key, None)

    def ratelimit(
        self,
        key: Any,
        time_now_s: int,
        update: bool = True,
        rate_hz: Optional[float] = None,
        burst_count: Optional[int] = None,
    ):
        """Checks if an action can be performed. If not, raises a LimitExceededError

        Args:
            key: An arbitrary key used to classify an action
            time_now_s: The current time
            update: Whether to count this check as performing the action
            rate_hz: The long term number of actions that can be performed in a second.
                Overrides the value set during instantiation if set.
            burst_count: How many actions that can be performed before being limited.
                Overrides the value set during instantiation if set.

        Raises:
            LimitExceededError: If an action could not be performed, along with the time in
                milliseconds until the action can be performed again
        """
        # Override default values if set
        rate_hz = rate_hz or self.rate_hz
        burst_count = burst_count or self.burst_count

        allowed, time_allowed = self.can_do_action(
            key, time_now_s, update=update, rate_hz=rate_hz, burst_count=burst_count
        )

        if not allowed:
            # NOTE(review): when the effective rate is not positive, can_do_action
            # returns -1 for time_allowed, making retry_after_ms negative here.
            # Confirm how clients should treat that case.
            raise LimitExceededError(
                retry_after_ms=int(1000 * (time_allowed - time_now_s))
            )