summary refs log tree commit diff
path: root/synapse/rest/client/room_batch.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/room_batch.py')
-rw-r--r--synapse/rest/client/room_batch.py31
1 files changed, 22 insertions, 9 deletions
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index 3172aba605..ed96978448 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -14,10 +14,14 @@
 
 import logging
 import re
+from typing import TYPE_CHECKING, Awaitable, List, Tuple
+
+from twisted.web.server import Request
 
 from synapse.api.constants import EventContentFields, EventTypes
 from synapse.api.errors import AuthError, Codes, SynapseError
 from synapse.appservice import ApplicationService
+from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
@@ -25,10 +29,14 @@ from synapse.http.servlet import (
     parse_string,
     parse_strings_from_args,
 )
+from synapse.http.site import SynapseRequest
 from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.types import Requester, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, create_requester
 from synapse.util.stringutils import random_string
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -66,7 +74,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
         ),
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.store = hs.get_datastore()
@@ -76,7 +84,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.txns = HttpTransactionCache(hs)
 
-    async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int:
+    async def _inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int:
         (
             most_recent_prev_event_id,
             most_recent_prev_event_depth,
@@ -118,7 +126,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
 
     def _create_insertion_event_dict(
         self, sender: str, room_id: str, origin_server_ts: int
-    ):
+    ) -> JsonDict:
         """Creates an event dict for an "insertion" event with the proper fields
         and a random chunk ID.
 
@@ -128,7 +136,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
             origin_server_ts: Timestamp when the event was sent
 
         Returns:
-            Tuple of event ID and stream ordering position
+            The new event dictionary to insert.
         """
 
         next_chunk_id = random_string(8)
@@ -164,7 +172,9 @@ class RoomBatchSendEventRestServlet(RestServlet):
 
         return create_requester(user_id, app_service=app_service)
 
-    async def on_POST(self, request, room_id):
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
 
         if not requester.app_service:
@@ -176,6 +186,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
         body = parse_json_object_from_request(request)
         assert_params_in_dict(body, ["state_events_at_start", "events"])
 
+        assert request.args is not None
         prev_events_from_query = parse_strings_from_args(request.args, "prev_event")
         chunk_id_from_query = parse_string(request, "chunk_id")
 
@@ -425,16 +436,18 @@ class RoomBatchSendEventRestServlet(RestServlet):
             ],
         }
 
-    def on_GET(self, request, room_id):
+    def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]:
         return 501, "Not implemented"
 
-    def on_PUT(self, request, room_id):
+    def on_PUT(
+        self, request: SynapseRequest, room_id: str
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         return self.txns.fetch_or_execute_request(
             request, self.on_POST, request, room_id
         )
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     msc2716_enabled = hs.config.experimental.msc2716_enabled
 
     if msc2716_enabled: