summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/admin/__init__.py2
-rw-r--r--synapse/rest/admin/registration_tokens.py21
-rw-r--r--synapse/rest/admin/users.py8
-rw-r--r--synapse/rest/client/login.py6
-rw-r--r--synapse/rest/client/login_token_request.py10
-rw-r--r--synapse/rest/client/presence.py2
-rw-r--r--synapse/rest/client/read_marker.py4
-rw-r--r--synapse/rest/client/receipts.py4
-rw-r--r--synapse/rest/client/register.py3
-rw-r--r--synapse/rest/client/report_event.py2
-rw-r--r--synapse/rest/client/room.py4
-rw-r--r--synapse/rest/client/sync.py1
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py39
13 files changed, 72 insertions, 34 deletions
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index fe8177ed4d..0d42c89ff7 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -156,7 +156,7 @@ class PurgeHistoryRestServlet(RestServlet):
             logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
         elif "purge_up_to_ts" in body:
             ts = body["purge_up_to_ts"]
-            if type(ts) is not int:
+            if type(ts) is not int:  # noqa: E721
                 raise SynapseError(
                     HTTPStatus.BAD_REQUEST,
                     "purge_up_to_ts must be an int",
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index 95e751288b..ffce92d45e 100644
--- a/synapse/rest/admin/registration_tokens.py
+++ b/synapse/rest/admin/registration_tokens.py
@@ -143,7 +143,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
         else:
             # Get length of token to generate (default is 16)
             length = body.get("length", 16)
-            if type(length) is not int:
+            if type(length) is not int:  # noqa: E721
                 raise SynapseError(
                     HTTPStatus.BAD_REQUEST,
                     "length must be an integer",
@@ -163,7 +163,8 @@ class NewRegistrationTokenRestServlet(RestServlet):
 
         uses_allowed = body.get("uses_allowed", None)
         if not (
-            uses_allowed is None or (type(uses_allowed) is int and uses_allowed >= 0)
+            uses_allowed is None
+            or (type(uses_allowed) is int and uses_allowed >= 0)  # noqa: E721
         ):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
@@ -172,13 +173,16 @@ class NewRegistrationTokenRestServlet(RestServlet):
             )
 
         expiry_time = body.get("expiry_time", None)
-        if type(expiry_time) not in (int, type(None)):
+        if expiry_time is not None and type(expiry_time) is not int:  # noqa: E721
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "expiry_time must be an integer or null",
                 Codes.INVALID_PARAM,
             )
-        if type(expiry_time) is int and expiry_time < self.clock.time_msec():
+        if (
+            type(expiry_time) is int  # noqa: E721
+            and expiry_time < self.clock.time_msec()
+        ):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "expiry_time must not be in the past",
@@ -283,7 +287,7 @@ class RegistrationTokenRestServlet(RestServlet):
             uses_allowed = body["uses_allowed"]
             if not (
                 uses_allowed is None
-                or (type(uses_allowed) is int and uses_allowed >= 0)
+                or (type(uses_allowed) is int and uses_allowed >= 0)  # noqa: E721
             ):
                 raise SynapseError(
                     HTTPStatus.BAD_REQUEST,
@@ -294,13 +298,16 @@ class RegistrationTokenRestServlet(RestServlet):
 
         if "expiry_time" in body:
             expiry_time = body["expiry_time"]
-            if type(expiry_time) not in (int, type(None)):
+            if expiry_time is not None and type(expiry_time) is not int:  # noqa: E721
                 raise SynapseError(
                     HTTPStatus.BAD_REQUEST,
                     "expiry_time must be an integer or null",
                     Codes.INVALID_PARAM,
                 )
-            if type(expiry_time) is int and expiry_time < self.clock.time_msec():
+            if (
+                type(expiry_time) is int  # noqa: E721
+                and expiry_time < self.clock.time_msec()
+            ):
                 raise SynapseError(
                     HTTPStatus.BAD_REQUEST,
                     "expiry_time must not be in the past",
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 240e6254b0..91898a5c13 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -132,6 +132,7 @@ class UsersRestServletV2(RestServlet):
                 UserSortOrder.AVATAR_URL.value,
                 UserSortOrder.SHADOW_BANNED.value,
                 UserSortOrder.CREATION_TS.value,
+                UserSortOrder.LAST_SEEN_TS.value,
             ),
         )
 
@@ -1172,14 +1173,17 @@ class RateLimitRestServlet(RestServlet):
         messages_per_second = body.get("messages_per_second", 0)
         burst_count = body.get("burst_count", 0)
 
-        if type(messages_per_second) is not int or messages_per_second < 0:
+        if (
+            type(messages_per_second) is not int  # noqa: E721
+            or messages_per_second < 0
+        ):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "%r parameter must be a positive int" % (messages_per_second,),
                 errcode=Codes.INVALID_PARAM,
             )
 
-        if type(burst_count) is not int or burst_count < 0:
+        if type(burst_count) is not int or burst_count < 0:  # noqa: E721
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "%r parameter must be a positive int" % (burst_count,),
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index d724c68920..7be327e26f 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -120,14 +120,12 @@ class LoginRestServlet(RestServlet):
         self._address_ratelimiter = Ratelimiter(
             store=self._main_store,
             clock=hs.get_clock(),
-            rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second,
-            burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count,
+            cfg=self.hs.config.ratelimiting.rc_login_address,
         )
         self._account_ratelimiter = Ratelimiter(
             store=self._main_store,
             clock=hs.get_clock(),
-            rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second,
-            burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count,
+            cfg=self.hs.config.ratelimiting.rc_login_account,
         )
 
         # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance.
diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py
index b1629f94a5..d189a923b5 100644
--- a/synapse/rest/client/login_token_request.py
+++ b/synapse/rest/client/login_token_request.py
@@ -16,6 +16,7 @@ import logging
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.ratelimiting import Ratelimiter
+from synapse.config.ratelimiting import RatelimitSettings
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
@@ -66,15 +67,18 @@ class LoginTokenRequestServlet(RestServlet):
         self.token_timeout = hs.config.auth.login_via_existing_token_timeout
         self._require_ui_auth = hs.config.auth.login_via_existing_require_ui_auth
 
-        # Ratelimit aggressively to a maxmimum of 1 request per minute.
+        # Ratelimit aggressively to a maximum of 1 request per minute.
         #
         # This endpoint can be used to spawn additional sessions and could be
         # abused by a malicious client to create many sessions.
         self._ratelimiter = Ratelimiter(
             store=self._main_store,
             clock=hs.get_clock(),
-            rate_hz=1 / 60,
-            burst_count=1,
+            cfg=RatelimitSettings(
+                key="<login token request>",
+                per_second=1 / 60,
+                burst_count=1,
+            ),
         )
 
     @interactive_auth_handler
diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py
index 8e193330f8..d578faa969 100644
--- a/synapse/rest/client/presence.py
+++ b/synapse/rest/client/presence.py
@@ -97,7 +97,7 @@ class PresenceStatusRestServlet(RestServlet):
             raise SynapseError(400, "Unable to parse state")
 
         if self._use_presence:
-            await self.presence_handler.set_state(user, state)
+            await self.presence_handler.set_state(user, requester.device_id, state)
 
         return 200, {}
 
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 4f96e51eeb..1707e51972 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -52,7 +52,9 @@ class ReadMarkerRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
-        await self.presence_handler.bump_presence_active_time(requester.user)
+        await self.presence_handler.bump_presence_active_time(
+            requester.user, requester.device_id
+        )
 
         body = parse_json_object_from_request(request)
 
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 316e7b9982..869a374459 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -94,7 +94,9 @@ class ReceiptRestServlet(RestServlet):
                     Codes.INVALID_PARAM,
                 )
 
-        await self.presence_handler.bump_presence_active_time(requester.user)
+        await self.presence_handler.bump_presence_active_time(
+            requester.user, requester.device_id
+        )
 
         if receipt_type == ReceiptTypes.FULLY_READ:
             await self.read_marker_handler.received_client_read_marker(
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 77e3b91b79..132623462a 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -376,8 +376,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
         self.ratelimiter = Ratelimiter(
             store=self.store,
             clock=hs.get_clock(),
-            rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second,
-            burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
+            cfg=hs.config.ratelimiting.rc_registration_token_validity,
         )
 
     async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py
index ac1a63ca27..ee93e459f6 100644
--- a/synapse/rest/client/report_event.py
+++ b/synapse/rest/client/report_event.py
@@ -55,7 +55,7 @@ class ReportEventRestServlet(RestServlet):
                 "Param 'reason' must be a string",
                 Codes.BAD_JSON,
             )
-        if type(body.get("score", 0)) is not int:
+        if type(body.get("score", 0)) is not int:  # noqa: E721
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "Param 'score' must be an integer",
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index dc498001e4..553938ce9d 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -1229,7 +1229,9 @@ class RoomTypingRestServlet(RestServlet):
 
         content = parse_json_object_from_request(request)
 
-        await self.presence_handler.bump_presence_active_time(requester.user)
+        await self.presence_handler.bump_presence_active_time(
+            requester.user, requester.device_id
+        )
 
         # Limit timeout to stop people from setting silly typing timeouts.
         timeout = min(content.get("timeout", 30000), 120000)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index d7854ed4fd..42bdd3bb10 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -205,6 +205,7 @@ class SyncRestServlet(RestServlet):
 
         context = await self.presence_handler.user_syncing(
             user.to_string(),
+            requester.device_id,
             affect_presence=affect_presence,
             presence_state=set_presence,
         )
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 981fd1f58a..0aaa838d04 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -16,6 +16,7 @@ import logging
 import re
 from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
 
+from pydantic import Extra, StrictInt, StrictStr
 from signedjson.sign import sign_json
 
 from twisted.web.server import Request
@@ -24,9 +25,10 @@ from synapse.crypto.keyring import ServerKeyFetcher
 from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
+    parse_and_validate_json_object_from_request,
     parse_integer,
-    parse_json_object_from_request,
 )
+from synapse.rest.models import RequestBodyModel
 from synapse.storage.keys import FetchKeyResultForRemote
 from synapse.types import JsonDict
 from synapse.util import json_decoder
@@ -38,6 +40,13 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class _KeyQueryCriteriaDataModel(RequestBodyModel):
+    class Config:
+        extra = Extra.allow
+
+    minimum_valid_until_ts: Optional[StrictInt]
+
+
 class RemoteKey(RestServlet):
     """HTTP resource for retrieving the TLS certificate and NACL signature
     verification keys for a collection of servers. Checks that the reported
@@ -96,6 +105,9 @@ class RemoteKey(RestServlet):
 
     CATEGORY = "Federation requests"
 
+    class PostBody(RequestBodyModel):
+        server_keys: Dict[StrictStr, Dict[StrictStr, _KeyQueryCriteriaDataModel]]
+
     def __init__(self, hs: "HomeServer"):
         self.fetcher = ServerKeyFetcher(hs)
         self.store = hs.get_datastores().main
@@ -137,24 +149,29 @@ class RemoteKey(RestServlet):
             )
 
             minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
-            arguments = {}
-            if minimum_valid_until_ts is not None:
-                arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
-            query = {server: {key_id: arguments}}
+            query = {
+                server: {
+                    key_id: _KeyQueryCriteriaDataModel(
+                        minimum_valid_until_ts=minimum_valid_until_ts
+                    )
+                }
+            }
         else:
             query = {server: {}}
 
         return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
 
     async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
-        content = parse_json_object_from_request(request)
+        content = parse_and_validate_json_object_from_request(request, self.PostBody)
 
-        query = content["server_keys"]
+        query = content.server_keys
 
         return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
 
     async def query_keys(
-        self, query: JsonDict, query_remote_on_cache_miss: bool = False
+        self,
+        query: Dict[str, Dict[str, _KeyQueryCriteriaDataModel]],
+        query_remote_on_cache_miss: bool = False,
     ) -> JsonDict:
         logger.info("Handling query for keys %r", query)
 
@@ -196,8 +213,10 @@ class RemoteKey(RestServlet):
             else:
                 ts_added_ms = key_result.added_ts
                 ts_valid_until_ms = key_result.valid_until_ts
-                req_key = query.get(server_name, {}).get(key_id, {})
-                req_valid_until = req_key.get("minimum_valid_until_ts")
+                req_key = query.get(server_name, {}).get(
+                    key_id, _KeyQueryCriteriaDataModel(minimum_valid_until_ts=None)
+                )
+                req_valid_until = req_key.minimum_valid_until_ts
                 if req_valid_until is not None:
                     if ts_valid_until_ms < req_valid_until:
                         logger.debug(