summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9648.feature1
-rw-r--r--docs/admin_api/user_admin_api.rst117
-rw-r--r--synapse/rest/admin/__init__.py2
-rw-r--r--synapse/rest/admin/users.py111
-rw-r--r--synapse/storage/databases/main/room.py64
-rw-r--r--tests/rest/admin/test_user.py284
6 files changed, 573 insertions, 6 deletions
diff --git a/changelog.d/9648.feature b/changelog.d/9648.feature
new file mode 100644
index 0000000000..bc77026039
--- /dev/null
+++ b/changelog.d/9648.feature
@@ -0,0 +1 @@
+Add an admin API to manage ratelimit for a specific user.
\ No newline at end of file
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index a8a5a2628c..dbce9c90b6 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -202,7 +202,7 @@ The following fields are returned in the JSON response body:
 - ``users`` - An array of objects, each containing information about an user.
   User objects contain the following fields:
 
-  - ``name`` - string - Fully-qualified user ID (ex. `@user:server.com`).
+  - ``name`` - string - Fully-qualified user ID (ex. ``@user:server.com``).
   - ``is_guest`` - bool - Status if that user is a guest account.
   - ``admin`` - bool - Status if that user is a server administrator.
   - ``user_type`` - string - Type of the user. Normal users are type ``None``.
@@ -864,3 +864,118 @@ The following parameters should be set in the URL:
 
 - ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
   be local.
+
+Override ratelimiting for users
+===============================
+
+This API allows to override or disable ratelimiting for a specific user.
+There are specific APIs to set, get and delete a ratelimit.
+
+Get status of ratelimit
+-----------------------
+
+The API is::
+
+  GET /_synapse/admin/v1/users/<user_id>/override_ratelimit
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+A response body like the following is returned:
+
+.. code:: json
+
+    {
+      "messages_per_second": 0,
+      "burst_count": 0
+    }
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
+  be local.
+
+**Response**
+
+The following fields are returned in the JSON response body:
+
+- ``messages_per_second`` - integer - The number of actions that can
+  be performed in a second. `0` mean that ratelimiting is disabled for this user.
+- ``burst_count`` - integer - How many actions that can be performed before
+  being limited.
+
+If **no** custom ratelimit is set, an empty JSON dict is returned.
+
+.. code:: json
+
+    {}
+
+Set ratelimit
+-------------
+
+The API is::
+
+  POST /_synapse/admin/v1/users/<user_id>/override_ratelimit
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+A response body like the following is returned:
+
+.. code:: json
+
+    {
+      "messages_per_second": 0,
+      "burst_count": 0
+    }
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
+  be local.
+
+Body parameters:
+
+- ``messages_per_second`` - positive integer, optional. The number of actions that can
+  be performed in a second. Defaults to ``0``.
+- ``burst_count`` - positive integer, optional. How many actions that can be performed
+  before being limited. Defaults to ``0``.
+
+To disable users' ratelimit set both values to ``0``.
+
+**Response**
+
+The following fields are returned in the JSON response body:
+
+- ``messages_per_second`` - integer - The number of actions that can
+  be performed in a second.
+- ``burst_count`` - integer - How many actions that can be performed before
+  being limited.
+
+Delete ratelimit
+----------------
+
+The API is::
+
+  DELETE /_synapse/admin/v1/users/<user_id>/override_ratelimit
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+An empty JSON dict is returned.
+
+.. code:: json
+
+    {}
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
+  be local.
+
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 5daa795df1..2dec818a5f 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -54,6 +54,7 @@ from synapse.rest.admin.users import (
     AccountValidityRenewServlet,
     DeactivateAccountRestServlet,
     PushersRestServlet,
+    RateLimitRestServlet,
     ResetPasswordRestServlet,
     SearchUsersRestServlet,
     ShadowBanRestServlet,
@@ -239,6 +240,7 @@ def register_servlets(hs, http_server):
     ShadowBanRestServlet(hs).register(http_server)
     ForwardExtremitiesRestServlet(hs).register(http_server)
     RoomEventContextServlet(hs).register(http_server)
+    RateLimitRestServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 595898c259..04990c71fb 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -981,3 +981,114 @@ class ShadowBanRestServlet(RestServlet):
         await self.store.set_shadow_banned(UserID.from_string(user_id), True)
 
         return 200, {}
+
+
+class RateLimitRestServlet(RestServlet):
+    """An admin API to override ratelimiting for an user.
+
+    Example:
+        POST /_synapse/admin/v1/users/@test:example.com/override_ratelimit
+        {
+          "messages_per_second": 0,
+          "burst_count": 0
+        }
+        200 OK
+        {
+          "messages_per_second": 0,
+          "burst_count": 0
+        }
+    """
+
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.hs.is_mine_id(user_id):
+            raise SynapseError(400, "Can only lookup local users")
+
+        if not await self.store.get_user_by_id(user_id):
+            raise NotFoundError("User not found")
+
+        ratelimit = await self.store.get_ratelimit_for_user(user_id)
+
+        if ratelimit:
+            # convert `null` to `0` for consistency
+            # both values do the same in retelimit handler
+            ret = {
+                "messages_per_second": 0
+                if ratelimit.messages_per_second is None
+                else ratelimit.messages_per_second,
+                "burst_count": 0
+                if ratelimit.burst_count is None
+                else ratelimit.burst_count,
+            }
+        else:
+            ret = {}
+
+        return 200, ret
+
+    async def on_POST(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.hs.is_mine_id(user_id):
+            raise SynapseError(400, "Only local users can be ratelimited")
+
+        if not await self.store.get_user_by_id(user_id):
+            raise NotFoundError("User not found")
+
+        body = parse_json_object_from_request(request, allow_empty_body=True)
+
+        messages_per_second = body.get("messages_per_second", 0)
+        burst_count = body.get("burst_count", 0)
+
+        if not isinstance(messages_per_second, int) or messages_per_second < 0:
+            raise SynapseError(
+                400,
+                "%r parameter must be a positive int" % (messages_per_second,),
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        if not isinstance(burst_count, int) or burst_count < 0:
+            raise SynapseError(
+                400,
+                "%r parameter must be a positive int" % (burst_count,),
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        await self.store.set_ratelimit_for_user(
+            user_id, messages_per_second, burst_count
+        )
+        ratelimit = await self.store.get_ratelimit_for_user(user_id)
+        assert ratelimit is not None
+
+        ret = {
+            "messages_per_second": ratelimit.messages_per_second,
+            "burst_count": ratelimit.burst_count,
+        }
+
+        return 200, ret
+
+    async def on_DELETE(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.hs.is_mine_id(user_id):
+            raise SynapseError(400, "Only local users can be ratelimited")
+
+        if not await self.store.get_user_by_id(user_id):
+            raise NotFoundError("User not found")
+
+        await self.store.delete_ratelimit_for_user(user_id)
+
+        return 200, {}
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 9cbcd53026..47fb12f3f6 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -521,13 +521,11 @@ class RoomWorkerStore(SQLBaseStore):
         )
 
     @cached(max_entries=10000)
-    async def get_ratelimit_for_user(self, user_id):
-        """Check if there are any overrides for ratelimiting for the given
-        user
+    async def get_ratelimit_for_user(self, user_id: str) -> Optional[RatelimitOverride]:
+        """Check if there are any overrides for ratelimiting for the given user
 
         Args:
-            user_id (str)
-
+            user_id: user ID of the user
         Returns:
             RatelimitOverride if there is an override, else None. If the contents
             of RatelimitOverride are None or 0 then ratelimitng has been
@@ -549,6 +547,62 @@ class RoomWorkerStore(SQLBaseStore):
         else:
             return None
 
+    async def set_ratelimit_for_user(
+        self, user_id: str, messages_per_second: int, burst_count: int
+    ) -> None:
+        """Sets whether a user is set an overridden ratelimit.
+        Args:
+            user_id: user ID of the user
+            messages_per_second: The number of actions that can be performed in a second.
+            burst_count: How many actions that can be performed before being limited.
+        """
+
+        def set_ratelimit_txn(txn):
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="ratelimit_override",
+                keyvalues={"user_id": user_id},
+                values={
+                    "messages_per_second": messages_per_second,
+                    "burst_count": burst_count,
+                },
+            )
+
+            self._invalidate_cache_and_stream(
+                txn, self.get_ratelimit_for_user, (user_id,)
+            )
+
+        await self.db_pool.runInteraction("set_ratelimit", set_ratelimit_txn)
+
+    async def delete_ratelimit_for_user(self, user_id: str) -> None:
+        """Delete an overridden ratelimit for a user.
+        Args:
+            user_id: user ID of the user
+        """
+
+        def delete_ratelimit_txn(txn):
+            row = self.db_pool.simple_select_one_txn(
+                txn,
+                table="ratelimit_override",
+                keyvalues={"user_id": user_id},
+                retcols=["user_id"],
+                allow_none=True,
+            )
+
+            if not row:
+                return
+
+            # They are there, delete them.
+            self.db_pool.simple_delete_one_txn(
+                txn, "ratelimit_override", keyvalues={"user_id": user_id}
+            )
+
+            self._invalidate_cache_and_stream(
+                txn, self.get_ratelimit_for_user, (user_id,)
+            )
+
+        await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
+
     @cached()
     async def get_retention_policy_for_room(self, room_id):
         """Get the retention policy for a given room.
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index e47fd3ded8..5070c96984 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -3011,3 +3011,287 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
         # Ensure the user is shadow-banned (and the cache was cleared).
         result = self.get_success(self.store.get_user_by_access_token(other_user_token))
         self.assertTrue(result.shadow_banned)
+
+
+class RateLimitTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.url = (
+            "/_synapse/admin/v1/users/%s/override_ratelimit"
+            % urllib.parse.quote(self.other_user)
+        )
+
+    def test_no_auth(self):
+        """
+        Try to get information of a user without authentication.
+        """
+        channel = self.make_request("GET", self.url, b"{}")
+
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+        channel = self.make_request("POST", self.url, b"{}")
+
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+        channel = self.make_request("DELETE", self.url, b"{}")
+
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_no_admin(self):
+        """
+        If the user is not a server admin, an error is returned.
+        """
+        other_user_token = self.login("user", "pass")
+
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=other_user_token,
+        )
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=other_user_token,
+        )
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            access_token=other_user_token,
+        )
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_user_does_not_exist(self):
+        """
+        Tests that a lookup for a user that does not exist returns a 404
+        """
+        url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit"
+
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+        channel = self.make_request(
+            "POST",
+            url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+        channel = self.make_request(
+            "DELETE",
+            url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_user_is_not_local(self):
+        """
+        Tests that a lookup for a user that is not a local returns a 400
+        """
+        url = (
+            "/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit"
+        )
+
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+        channel = self.make_request(
+            "POST",
+            url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual(
+            "Only local users can be ratelimited", channel.json_body["error"]
+        )
+
+        channel = self.make_request(
+            "DELETE",
+            url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual(
+            "Only local users can be ratelimited", channel.json_body["error"]
+        )
+
+    def test_invalid_parameter(self):
+        """
+        If parameters are invalid, an error is returned.
+        """
+        # messages_per_second is a string
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={"messages_per_second": "string"},
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # messages_per_second is negative
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={"messages_per_second": -1},
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # burst_count is a string
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={"burst_count": "string"},
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # burst_count is negative
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={"burst_count": -1},
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+    def test_return_zero_when_null(self):
+        """
+        If values in database are `null` API should return an int `0`
+        """
+
+        self.get_success(
+            self.store.db_pool.simple_upsert(
+                table="ratelimit_override",
+                keyvalues={"user_id": self.other_user},
+                values={
+                    "messages_per_second": None,
+                    "burst_count": None,
+                },
+            )
+        )
+
+        # request status
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(0, channel.json_body["messages_per_second"])
+        self.assertEqual(0, channel.json_body["burst_count"])
+
+    def test_success(self):
+        """
+        Rate-limiting (set/update/delete) should succeed for an admin.
+        """
+        # request status
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertNotIn("messages_per_second", channel.json_body)
+        self.assertNotIn("burst_count", channel.json_body)
+
+        # set ratelimit
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={"messages_per_second": 10, "burst_count": 11},
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(10, channel.json_body["messages_per_second"])
+        self.assertEqual(11, channel.json_body["burst_count"])
+
+        # update ratelimit
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={"messages_per_second": 20, "burst_count": 21},
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(20, channel.json_body["messages_per_second"])
+        self.assertEqual(21, channel.json_body["burst_count"])
+
+        # request status
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(20, channel.json_body["messages_per_second"])
+        self.assertEqual(21, channel.json_body["burst_count"])
+
+        # delete ratelimit
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertNotIn("messages_per_second", channel.json_body)
+        self.assertNotIn("burst_count", channel.json_body)
+
+        # request status
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertNotIn("messages_per_second", channel.json_body)
+        self.assertNotIn("burst_count", channel.json_body)