summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/rest/admin/__init__.py2
-rw-r--r--synapse/rest/admin/users.py111
-rw-r--r--synapse/storage/databases/main/room.py64
3 files changed, 172 insertions, 5 deletions
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 5daa795df1..2dec818a5f 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -54,6 +54,7 @@ from synapse.rest.admin.users import (
     AccountValidityRenewServlet,
     DeactivateAccountRestServlet,
     PushersRestServlet,
+    RateLimitRestServlet,
     ResetPasswordRestServlet,
     SearchUsersRestServlet,
     ShadowBanRestServlet,
@@ -239,6 +240,7 @@ def register_servlets(hs, http_server):
     ShadowBanRestServlet(hs).register(http_server)
     ForwardExtremitiesRestServlet(hs).register(http_server)
     RoomEventContextServlet(hs).register(http_server)
+    RateLimitRestServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 595898c259..04990c71fb 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -981,3 +981,114 @@ class ShadowBanRestServlet(RestServlet):
         await self.store.set_shadow_banned(UserID.from_string(user_id), True)
 
         return 200, {}
+
+
+class RateLimitRestServlet(RestServlet):
+    """An admin API to override ratelimiting for an user.
+
+    Example:
+        POST /_synapse/admin/v1/users/@test:example.com/override_ratelimit
+        {
+          "messages_per_second": 0,
+          "burst_count": 0
+        }
+        200 OK
+        {
+          "messages_per_second": 0,
+          "burst_count": 0
+        }
+    """
+
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.hs.is_mine_id(user_id):
+            raise SynapseError(400, "Can only lookup local users")
+
+        if not await self.store.get_user_by_id(user_id):
+            raise NotFoundError("User not found")
+
+        ratelimit = await self.store.get_ratelimit_for_user(user_id)
+
+        if ratelimit:
+            # convert `null` to `0` for consistency
+            # both values do the same in retelimit handler
+            ret = {
+                "messages_per_second": 0
+                if ratelimit.messages_per_second is None
+                else ratelimit.messages_per_second,
+                "burst_count": 0
+                if ratelimit.burst_count is None
+                else ratelimit.burst_count,
+            }
+        else:
+            ret = {}
+
+        return 200, ret
+
+    async def on_POST(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.hs.is_mine_id(user_id):
+            raise SynapseError(400, "Only local users can be ratelimited")
+
+        if not await self.store.get_user_by_id(user_id):
+            raise NotFoundError("User not found")
+
+        body = parse_json_object_from_request(request, allow_empty_body=True)
+
+        messages_per_second = body.get("messages_per_second", 0)
+        burst_count = body.get("burst_count", 0)
+
+        if not isinstance(messages_per_second, int) or messages_per_second < 0:
+            raise SynapseError(
+                400,
+                "%r parameter must be a positive int" % (messages_per_second,),
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        if not isinstance(burst_count, int) or burst_count < 0:
+            raise SynapseError(
+                400,
+                "%r parameter must be a positive int" % (burst_count,),
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        await self.store.set_ratelimit_for_user(
+            user_id, messages_per_second, burst_count
+        )
+        ratelimit = await self.store.get_ratelimit_for_user(user_id)
+        assert ratelimit is not None
+
+        ret = {
+            "messages_per_second": ratelimit.messages_per_second,
+            "burst_count": ratelimit.burst_count,
+        }
+
+        return 200, ret
+
+    async def on_DELETE(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.hs.is_mine_id(user_id):
+            raise SynapseError(400, "Only local users can be ratelimited")
+
+        if not await self.store.get_user_by_id(user_id):
+            raise NotFoundError("User not found")
+
+        await self.store.delete_ratelimit_for_user(user_id)
+
+        return 200, {}
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 9cbcd53026..47fb12f3f6 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -521,13 +521,11 @@ class RoomWorkerStore(SQLBaseStore):
         )
 
     @cached(max_entries=10000)
-    async def get_ratelimit_for_user(self, user_id):
-        """Check if there are any overrides for ratelimiting for the given
-        user
+    async def get_ratelimit_for_user(self, user_id: str) -> Optional[RatelimitOverride]:
+        """Check if there are any overrides for ratelimiting for the given user
 
         Args:
-            user_id (str)
-
+            user_id: user ID of the user
         Returns:
             RatelimitOverride if there is an override, else None. If the contents
             of RatelimitOverride are None or 0 then ratelimitng has been
@@ -549,6 +547,62 @@ class RoomWorkerStore(SQLBaseStore):
         else:
             return None
 
+    async def set_ratelimit_for_user(
+        self, user_id: str, messages_per_second: int, burst_count: int
+    ) -> None:
+        """Sets whether a user is set an overridden ratelimit.
+        Args:
+            user_id: user ID of the user
+            messages_per_second: The number of actions that can be performed in a second.
+            burst_count: How many actions that can be performed before being limited.
+        """
+
+        def set_ratelimit_txn(txn):
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="ratelimit_override",
+                keyvalues={"user_id": user_id},
+                values={
+                    "messages_per_second": messages_per_second,
+                    "burst_count": burst_count,
+                },
+            )
+
+            self._invalidate_cache_and_stream(
+                txn, self.get_ratelimit_for_user, (user_id,)
+            )
+
+        await self.db_pool.runInteraction("set_ratelimit", set_ratelimit_txn)
+
+    async def delete_ratelimit_for_user(self, user_id: str) -> None:
+        """Delete an overridden ratelimit for a user.
+        Args:
+            user_id: user ID of the user
+        """
+
+        def delete_ratelimit_txn(txn):
+            row = self.db_pool.simple_select_one_txn(
+                txn,
+                table="ratelimit_override",
+                keyvalues={"user_id": user_id},
+                retcols=["user_id"],
+                allow_none=True,
+            )
+
+            if not row:
+                return
+
+            # They are there, delete them.
+            self.db_pool.simple_delete_one_txn(
+                txn, "ratelimit_override", keyvalues={"user_id": user_id}
+            )
+
+            self._invalidate_cache_and_stream(
+                txn, self.get_ratelimit_for_user, (user_id,)
+            )
+
+        await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
+
     @cached()
     async def get_retention_policy_for_room(self, room_id):
         """Get the retention policy for a given room.