summary refs log tree commit diff
path: root/synapse/storage/databases/main/room.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/room.py')
-rw-r--r--synapse/storage/databases/main/room.py64
1 files changed, 59 insertions, 5 deletions
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 9cbcd53026..47fb12f3f6 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -521,13 +521,11 @@ class RoomWorkerStore(SQLBaseStore):
         )
 
     @cached(max_entries=10000)
-    async def get_ratelimit_for_user(self, user_id):
-        """Check if there are any overrides for ratelimiting for the given
-        user
+    async def get_ratelimit_for_user(self, user_id: str) -> Optional[RatelimitOverride]:
+        """Check if there are any overrides for ratelimiting for the given user
 
         Args:
-            user_id (str)
-
+            user_id: user ID of the user
         Returns:
             RatelimitOverride if there is an override, else None. If the contents
             of RatelimitOverride are None or 0 then ratelimitng has been
@@ -549,6 +547,62 @@ class RoomWorkerStore(SQLBaseStore):
         else:
             return None
 
+    async def set_ratelimit_for_user(
+        self, user_id: str, messages_per_second: int, burst_count: int
+    ) -> None:
+        """Sets whether a user is set an overridden ratelimit.
+        Args:
+            user_id: user ID of the user
+            messages_per_second: The number of actions that can be performed in a second.
+            burst_count: How many actions that can be performed before being limited.
+        """
+
+        def set_ratelimit_txn(txn):
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="ratelimit_override",
+                keyvalues={"user_id": user_id},
+                values={
+                    "messages_per_second": messages_per_second,
+                    "burst_count": burst_count,
+                },
+            )
+
+            self._invalidate_cache_and_stream(
+                txn, self.get_ratelimit_for_user, (user_id,)
+            )
+
+        await self.db_pool.runInteraction("set_ratelimit", set_ratelimit_txn)
+
+    async def delete_ratelimit_for_user(self, user_id: str) -> None:
+        """Delete an overridden ratelimit for a user.
+        Args:
+            user_id: user ID of the user
+        """
+
+        def delete_ratelimit_txn(txn):
+            row = self.db_pool.simple_select_one_txn(
+                txn,
+                table="ratelimit_override",
+                keyvalues={"user_id": user_id},
+                retcols=["user_id"],
+                allow_none=True,
+            )
+
+            if not row:
+                return
+
+            # They are there, delete them.
+            self.db_pool.simple_delete_one_txn(
+                txn, "ratelimit_override", keyvalues={"user_id": user_id}
+            )
+
+            self._invalidate_cache_and_stream(
+                txn, self.get_ratelimit_for_user, (user_id,)
+            )
+
+        await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
+
     @cached()
     async def get_retention_policy_for_room(self, room_id):
         """Get the retention policy for a given room.