summary refs log tree commit diff
path: root/synapse/rest/client/devices.py
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-10-07 13:54:07 +0100
committerGitHub <noreply@github.com>2022-10-07 13:54:07 +0100
commit2295095c97f3b4707f30ae8cb4562ebb799f7ac1 (patch)
tree2c084ff17e30f123bd0f4aea0191ed45d6a9238f /synapse/rest/client/devices.py
parentCatch BrokenPipeError from metrics server, and log as a warning (#14072) (diff)
downloadsynapse-2295095c97f3b4707f30ae8cb4562ebb799f7ac1.tar.xz
Use Pydantic to validate /devices endpoints (#14054)
Diffstat (limited to '')
-rw-r--r--synapse/rest/client/devices.py98
1 files changed, 52 insertions, 46 deletions
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index ed6ce78d47..90828c95c4 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -14,18 +14,21 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from pydantic import Extra, StrictStr
 
 from synapse.api import errors
 from synapse.api.errors import NotFoundError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
-    assert_params_in_dict,
-    parse_json_object_from_request,
+    parse_and_validate_json_object_from_request,
 )
 from synapse.http.site import SynapseRequest
 from synapse.rest.client._base import client_patterns, interactive_auth_handler
+from synapse.rest.client.models import AuthenticationData
+from synapse.rest.models import RequestBodyModel
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -80,27 +83,29 @@ class DeleteDevicesRestServlet(RestServlet):
         self.device_handler = hs.get_device_handler()
         self.auth_handler = hs.get_auth_handler()
 
+    class PostBody(RequestBodyModel):
+        auth: Optional[AuthenticationData]
+        devices: List[StrictStr]
+
     @interactive_auth_handler
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
         try:
-            body = parse_json_object_from_request(request)
+            body = parse_and_validate_json_object_from_request(request, self.PostBody)
         except errors.SynapseError as e:
             if e.errcode == errors.Codes.NOT_JSON:
-                # DELETE
+                # TODO: Can/should we remove this fallback now?
                 # deal with older clients which didn't pass a JSON dict
                 # the same as those that pass an empty dict
-                body = {}
+                body = self.PostBody.parse_obj({})
             else:
                 raise e
 
-        assert_params_in_dict(body, ["devices"])
-
         await self.auth_handler.validate_user_via_ui_auth(
             requester,
             request,
-            body,
+            body.dict(exclude_unset=True),
             "remove device(s) from your account",
             # Users might call this multiple times in a row while cleaning up
             # devices, allow a single UI auth session to be re-used.
@@ -108,7 +113,7 @@ class DeleteDevicesRestServlet(RestServlet):
         )
 
         await self.device_handler.delete_devices(
-            requester.user.to_string(), body["devices"]
+            requester.user.to_string(), body.devices
         )
         return 200, {}
 
@@ -147,6 +152,9 @@ class DeviceRestServlet(RestServlet):
 
         return 200, device
 
+    class DeleteBody(RequestBodyModel):
+        auth: Optional[AuthenticationData]
+
     @interactive_auth_handler
     async def on_DELETE(
         self, request: SynapseRequest, device_id: str
@@ -154,20 +162,21 @@ class DeviceRestServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request)
 
         try:
-            body = parse_json_object_from_request(request)
+            body = parse_and_validate_json_object_from_request(request, self.DeleteBody)
 
         except errors.SynapseError as e:
             if e.errcode == errors.Codes.NOT_JSON:
+                # TODO: can/should we remove this fallback now?
                 # deal with older clients which didn't pass a JSON dict
                 # the same as those that pass an empty dict
-                body = {}
+                body = self.DeleteBody.parse_obj({})
             else:
                 raise
 
         await self.auth_handler.validate_user_via_ui_auth(
             requester,
             request,
-            body,
+            body.dict(exclude_unset=True),
             "remove a device from your account",
             # Users might call this multiple times in a row while cleaning up
             # devices, allow a single UI auth session to be re-used.
@@ -179,18 +188,33 @@ class DeviceRestServlet(RestServlet):
         )
         return 200, {}
 
+    class PutBody(RequestBodyModel):
+        display_name: Optional[StrictStr]
+
     async def on_PUT(
         self, request: SynapseRequest, device_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
-        body = parse_json_object_from_request(request)
+        body = parse_and_validate_json_object_from_request(request, self.PutBody)
         await self.device_handler.update_device(
-            requester.user.to_string(), device_id, body
+            requester.user.to_string(), device_id, body.dict()
         )
         return 200, {}
 
 
+class DehydratedDeviceDataModel(RequestBodyModel):
+    """JSON blob describing a dehydrated device to be stored.
+
+    Expects other freeform fields. Use .dict() to access them.
+    """
+
+    class Config:
+        extra = Extra.allow
+
+    algorithm: StrictStr
+
+
 class DehydratedDeviceServlet(RestServlet):
     """Retrieve or store a dehydrated device.
 
@@ -246,27 +270,19 @@ class DehydratedDeviceServlet(RestServlet):
         else:
             raise errors.NotFoundError("No dehydrated device available")
 
+    class PutBody(RequestBodyModel):
+        device_id: StrictStr
+        device_data: DehydratedDeviceDataModel
+        initial_device_display_name: Optional[StrictStr]
+
     async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        submission = parse_json_object_from_request(request)
+        submission = parse_and_validate_json_object_from_request(request, self.PutBody)
         requester = await self.auth.get_user_by_req(request)
 
-        if "device_data" not in submission:
-            raise errors.SynapseError(
-                400,
-                "device_data missing",
-                errcode=errors.Codes.MISSING_PARAM,
-            )
-        elif not isinstance(submission["device_data"], dict):
-            raise errors.SynapseError(
-                400,
-                "device_data must be an object",
-                errcode=errors.Codes.INVALID_PARAM,
-            )
-
         device_id = await self.device_handler.store_dehydrated_device(
             requester.user.to_string(),
-            submission["device_data"],
-            submission.get("initial_device_display_name", None),
+            submission.device_data,
+            submission.initial_device_display_name,
         )
         return 200, {"device_id": device_id}
 
@@ -300,28 +316,18 @@ class ClaimDehydratedDeviceServlet(RestServlet):
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
 
+    class PostBody(RequestBodyModel):
+        device_id: StrictStr
+
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
-        submission = parse_json_object_from_request(request)
-
-        if "device_id" not in submission:
-            raise errors.SynapseError(
-                400,
-                "device_id missing",
-                errcode=errors.Codes.MISSING_PARAM,
-            )
-        elif not isinstance(submission["device_id"], str):
-            raise errors.SynapseError(
-                400,
-                "device_id must be a string",
-                errcode=errors.Codes.INVALID_PARAM,
-            )
+        submission = parse_and_validate_json_object_from_request(request, self.PostBody)
 
         result = await self.device_handler.rehydrate_device(
             requester.user.to_string(),
             self.auth.get_access_token_from_request(request),
-            submission["device_id"],
+            submission.device_id,
         )
 
         return 200, result