summary refs log tree commit diff
path: root/synapse/api/ratelimiting.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/api/ratelimiting.py')
-rw-r--r--synapse/api/ratelimiting.py43
1 files changed, 28 insertions, 15 deletions
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 3bb5b3da37..296c4a1c17 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -14,6 +14,8 @@
 
 import collections
 
+from synapse.api.errors import LimitExceededError
+
 
 class Ratelimiter(object):
     """
@@ -23,12 +25,13 @@ class Ratelimiter(object):
     def __init__(self):
         self.message_counts = collections.OrderedDict()
 
-    def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True):
-        """Can the user send a message?
+    def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True):
+        """Can the entity (e.g. user or IP address) perform the action?
         Args:
-            user_id: The user sending a message.
+            key: The key we should use when rate limiting. Can be a user ID
+                (when sending events), an IP address, etc.
             time_now_s: The time now.
-            msg_rate_hz: The long term number of messages a user can send in a
+            rate_hz: The long term number of messages a user can send in a
                 second.
             burst_count: How many messages the user can send before being
                 limited.
@@ -41,10 +44,10 @@ class Ratelimiter(object):
         """
         self.prune_message_counts(time_now_s)
         message_count, time_start, _ignored = self.message_counts.get(
-            user_id, (0., time_now_s, None),
+            key, (0., time_now_s, None),
         )
         time_delta = time_now_s - time_start
-        sent_count = message_count - time_delta * msg_rate_hz
+        sent_count = message_count - time_delta * rate_hz
         if sent_count < 0:
             allowed = True
             time_start = time_now_s
@@ -56,13 +59,13 @@ class Ratelimiter(object):
             message_count += 1
 
         if update:
-            self.message_counts[user_id] = (
-                message_count, time_start, msg_rate_hz
+            self.message_counts[key] = (
+                message_count, time_start, rate_hz
             )
 
-        if msg_rate_hz > 0:
+        if rate_hz > 0:
             time_allowed = (
-                time_start + (message_count - burst_count + 1) / msg_rate_hz
+                time_start + (message_count - burst_count + 1) / rate_hz
             )
             if time_allowed < time_now_s:
                 time_allowed = time_now_s
@@ -72,12 +75,22 @@ class Ratelimiter(object):
         return allowed, time_allowed
 
     def prune_message_counts(self, time_now_s):
-        for user_id in list(self.message_counts.keys()):
-            message_count, time_start, msg_rate_hz = (
-                self.message_counts[user_id]
+        for key in list(self.message_counts.keys()):
+            message_count, time_start, rate_hz = (
+                self.message_counts[key]
             )
             time_delta = time_now_s - time_start
-            if message_count - time_delta * msg_rate_hz > 0:
+            if message_count - time_delta * rate_hz > 0:
                 break
             else:
-                del self.message_counts[user_id]
+                del self.message_counts[key]
+
+    def ratelimit(self, key, time_now_s, rate_hz, burst_count, update=True):
+        allowed, time_allowed = self.can_do_action(
+            key, time_now_s, rate_hz, burst_count, update
+        )
+
+        if not allowed:
+            raise LimitExceededError(
+                retry_after_ms=int(1000 * (time_allowed - time_now_s)),
+            )