summary refs log tree commit diff
path: root/synapse/replication/http/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/http/federation.py')
-rw-r--r--synapse/replication/http/federation.py65
1 files changed, 42 insertions, 23 deletions
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 5ed535c90d..d529c8a19f 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -13,17 +13,22 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Tuple
 
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.events import make_event_from_dict
+from twisted.web.server import Request
+
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.events import EventBase, make_event_from_dict
 from synapse.events.snapshot import EventContext
+from synapse.http.server import HttpServer
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
+from synapse.types import JsonDict
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
 
@@ -69,14 +74,18 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         self.federation_event_handler = hs.get_federation_event_handler()
 
     @staticmethod
-    async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
+    async def _serialize_payload(  # type: ignore[override]
+        store: "DataStore",
+        room_id: str,
+        event_and_contexts: List[Tuple[EventBase, EventContext]],
+        backfilled: bool,
+    ) -> JsonDict:
         """
         Args:
             store
-            room_id (str)
-            event_and_contexts (list[tuple[FrozenEvent, EventContext]])
-            backfilled (bool): Whether or not the events are the result of
-                backfilling
+            room_id
+            event_and_contexts
+            backfilled: Whether or not the events are the result of backfilling
         """
         event_payloads = []
         for event, context in event_and_contexts:
@@ -102,7 +111,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
 
         return payload
 
-    async def _handle_request(self, request):
+    async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]:  # type: ignore[override]
         with Measure(self.clock, "repl_fed_send_events_parse"):
             content = parse_json_object_from_request(request)
 
@@ -163,10 +172,14 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
         self.registry = hs.get_federation_registry()
 
     @staticmethod
-    async def _serialize_payload(edu_type, origin, content):
+    async def _serialize_payload(  # type: ignore[override]
+        edu_type: str, origin: str, content: JsonDict
+    ) -> JsonDict:
         return {"origin": origin, "content": content}
 
-    async def _handle_request(self, request, edu_type):
+    async def _handle_request(  # type: ignore[override]
+        self, request: Request, edu_type: str
+    ) -> Tuple[int, JsonDict]:
         with Measure(self.clock, "repl_fed_send_edu_parse"):
             content = parse_json_object_from_request(request)
 
@@ -175,9 +188,9 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
 
         logger.info("Got %r edu from %s", edu_type, origin)
 
-        result = await self.registry.on_edu(edu_type, origin, edu_content)
+        await self.registry.on_edu(edu_type, origin, edu_content)
 
-        return 200, result
+        return 200, {}
 
 
 class ReplicationGetQueryRestServlet(ReplicationEndpoint):
@@ -206,15 +219,17 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
         self.registry = hs.get_federation_registry()
 
     @staticmethod
-    async def _serialize_payload(query_type, args):
+    async def _serialize_payload(query_type: str, args: JsonDict) -> JsonDict:  # type: ignore[override]
         """
         Args:
-            query_type (str)
-            args (dict): The arguments received for the given query type
+            query_type
+            args: The arguments received for the given query type
         """
         return {"args": args}
 
-    async def _handle_request(self, request, query_type):
+    async def _handle_request(  # type: ignore[override]
+        self, request: Request, query_type: str
+    ) -> Tuple[int, JsonDict]:
         with Measure(self.clock, "repl_fed_query_parse"):
             content = parse_json_object_from_request(request)
 
@@ -248,14 +263,16 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
         self.store = hs.get_datastore()
 
     @staticmethod
-    async def _serialize_payload(room_id, args):
+    async def _serialize_payload(room_id: str) -> JsonDict:  # type: ignore[override]
         """
         Args:
-            room_id (str)
+            room_id
         """
         return {}
 
-    async def _handle_request(self, request, room_id):
+    async def _handle_request(  # type: ignore[override]
+        self, request: Request, room_id: str
+    ) -> Tuple[int, JsonDict]:
         await self.store.clean_room_for_join(room_id)
 
         return 200, {}
@@ -283,17 +300,19 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
         self.store = hs.get_datastore()
 
     @staticmethod
-    async def _serialize_payload(room_id, room_version):
+    async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDict:  # type: ignore[override]
         return {"room_version": room_version.identifier}
 
-    async def _handle_request(self, request, room_id):
+    async def _handle_request(  # type: ignore[override]
+        self, request: Request, room_id: str
+    ) -> Tuple[int, JsonDict]:
         content = parse_json_object_from_request(request)
         room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
         await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
         return 200, {}
 
 
-def register_servlets(hs: "HomeServer", http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ReplicationFederationSendEventsRestServlet(hs).register(http_server)
     ReplicationFederationSendEduRestServlet(hs).register(http_server)
     ReplicationGetQueryRestServlet(hs).register(http_server)