summary refs log tree commit diff
path: root/synapse/rest/client/devices.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/rest/client/devices.py32
1 files changed, 21 insertions, 11 deletions
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 90828c95c4..486c6dbbc5 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -20,6 +20,7 @@ from pydantic import Extra, StrictStr
 
 from synapse.api import errors
 from synapse.api.errors import NotFoundError
+from synapse.handlers.device import DeviceHandler
 from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
@@ -80,7 +81,9 @@ class DeleteDevicesRestServlet(RestServlet):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.device_handler = hs.get_device_handler()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.device_handler = handler
         self.auth_handler = hs.get_auth_handler()
 
     class PostBody(RequestBodyModel):
@@ -125,7 +128,9 @@ class DeviceRestServlet(RestServlet):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.device_handler = hs.get_device_handler()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.device_handler = handler
         self.auth_handler = hs.get_auth_handler()
         self._msc3852_enabled = hs.config.experimental.msc3852_enabled
 
@@ -231,7 +236,7 @@ class DehydratedDeviceServlet(RestServlet):
       }
     }
 
-    PUT /org.matrix.msc2697/dehydrated_device
+    PUT /org.matrix.msc2697.v2/dehydrated_device
     Content-Type: application/json
 
     {
@@ -256,7 +261,9 @@ class DehydratedDeviceServlet(RestServlet):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.device_handler = hs.get_device_handler()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.device_handler = handler
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
@@ -271,7 +278,6 @@ class DehydratedDeviceServlet(RestServlet):
             raise errors.NotFoundError("No dehydrated device available")
 
     class PutBody(RequestBodyModel):
-        device_id: StrictStr
         device_data: DehydratedDeviceDataModel
         initial_device_display_name: Optional[StrictStr]
 
@@ -281,7 +287,7 @@ class DehydratedDeviceServlet(RestServlet):
 
         device_id = await self.device_handler.store_dehydrated_device(
             requester.user.to_string(),
-            submission.device_data,
+            submission.device_data.dict(),
             submission.initial_device_display_name,
         )
         return 200, {"device_id": device_id}
@@ -314,7 +320,9 @@ class ClaimDehydratedDeviceServlet(RestServlet):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.device_handler = hs.get_device_handler()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.device_handler = handler
 
     class PostBody(RequestBodyModel):
         device_id: StrictStr
@@ -334,8 +342,10 @@ class ClaimDehydratedDeviceServlet(RestServlet):
 
 
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
-    DeleteDevicesRestServlet(hs).register(http_server)
+    if hs.config.worker.worker_app is None:
+        DeleteDevicesRestServlet(hs).register(http_server)
     DevicesRestServlet(hs).register(http_server)
-    DeviceRestServlet(hs).register(http_server)
-    DehydratedDeviceServlet(hs).register(http_server)
-    ClaimDehydratedDeviceServlet(hs).register(http_server)
+    if hs.config.worker.worker_app is None:
+        DeviceRestServlet(hs).register(http_server)
+        DehydratedDeviceServlet(hs).register(http_server)
+        ClaimDehydratedDeviceServlet(hs).register(http_server)