summary refs log tree commit diff
path: root/synapse/storage/room.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/room.py')
-rw-r--r--synapse/storage/room.py36
1 files changed, 35 insertions, 1 deletions
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index e4c56cc175..5d543652bb 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -16,7 +16,7 @@
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
 from ._base import SQLBaseStore
 from .engines import PostgresEngine, Sqlite3Engine
@@ -33,6 +33,11 @@ OpsLevel = collections.namedtuple(
     ("ban_level", "kick_level", "redact_level",)
 )
 
+RatelimitOverride = collections.namedtuple(
+    "RatelimitOverride",
+    ("messages_per_second", "burst_count",)
+)
+
 
 class RoomStore(SQLBaseStore):
 
@@ -473,3 +478,32 @@ class RoomStore(SQLBaseStore):
         return self.runInteraction(
             "get_all_new_public_rooms", get_all_new_public_rooms
         )
+
+    @cachedInlineCallbacks(max_entries=10000)
+    def get_ratelimit_for_user(self, user_id):
+        """Check if there are any overrides for ratelimiting for the given
+        user
+
+        Args:
+            user_id (str)
+
+        Returns:
+            RatelimitOverride if there is an override, else None. If the contents
+            of RatelimitOverride are None or 0 then ratelimitng has been
+            disabled for that user entirely.
+        """
+        row = yield self._simple_select_one(
+            table="ratelimit_override",
+            keyvalues={"user_id": user_id},
+            retcols=("messages_per_second", "burst_count"),
+            allow_none=True,
+            desc="get_ratelimit_for_user",
+        )
+
+        if row:
+            defer.returnValue(RatelimitOverride(
+                messages_per_second=row["messages_per_second"],
+                burst_count=row["burst_count"],
+            ))
+        else:
+            defer.returnValue(None)