diff options
Diffstat (limited to 'synapse/replication/http/federation.py')
-rw-r--r-- | synapse/replication/http/federation.py | 65 |
1 files changed, 42 insertions, 23 deletions
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 5ed535c90d..d529c8a19f 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -13,17 +13,22 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List, Tuple -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.events import make_event_from_dict +from twisted.web.server import Request + +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext +from synapse.http.server import HttpServer from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import JsonDict from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -69,14 +74,18 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.federation_event_handler = hs.get_federation_event_handler() @staticmethod - async def _serialize_payload(store, room_id, event_and_contexts, backfilled): + async def _serialize_payload( # type: ignore[override] + store: "DataStore", + room_id: str, + event_and_contexts: List[Tuple[EventBase, EventContext]], + backfilled: bool, + ) -> JsonDict: """ Args: store - room_id (str) - event_and_contexts (list[tuple[FrozenEvent, EventContext]]) - backfilled (bool): Whether or not the events are the result of - backfilling + room_id + event_and_contexts + backfilled: Whether or not the events are the result of backfilling """ event_payloads = [] for event, context in event_and_contexts: @@ -102,7 +111,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): return payload - async def _handle_request(self, request): + async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # type: ignore[override] with Measure(self.clock, "repl_fed_send_events_parse"): content = parse_json_object_from_request(request) @@ -163,10 +172,14 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): self.registry = hs.get_federation_registry() @staticmethod - async def _serialize_payload(edu_type, origin, content): + async def _serialize_payload( # type: ignore[override] + edu_type: str, origin: str, content: JsonDict + ) -> JsonDict: return {"origin": origin, "content": content} - async def _handle_request(self, request, edu_type): + async def _handle_request( # type: ignore[override] + self, request: Request, edu_type: str + ) -> Tuple[int, JsonDict]: with Measure(self.clock, "repl_fed_send_edu_parse"): content = parse_json_object_from_request(request) @@ -175,9 +188,9 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): logger.info("Got %r edu from %s", edu_type, origin) - result = await self.registry.on_edu(edu_type, origin, edu_content) + await self.registry.on_edu(edu_type, origin, edu_content) - return 200, result + return 200, {} class ReplicationGetQueryRestServlet(ReplicationEndpoint): @@ -206,15 +219,17 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): self.registry = hs.get_federation_registry() @staticmethod - async def _serialize_payload(query_type, args): + async def _serialize_payload(query_type: str, args: JsonDict) -> JsonDict: # type: ignore[override] """ Args: - query_type (str) - args (dict): The arguments received for the given query type + query_type + args: The arguments received for the given query type """ return {"args": args} - async def _handle_request(self, request, query_type): + async def _handle_request( # type: ignore[override] + self, request: Request, query_type: str + ) -> Tuple[int, JsonDict]: with Measure(self.clock, "repl_fed_query_parse"): content = parse_json_object_from_request(request) @@ -248,14 +263,16 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): self.store = hs.get_datastore() @staticmethod - async def _serialize_payload(room_id, args): + async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override] """ Args: - room_id (str) + room_id """ return {} - async def _handle_request(self, request, room_id): + async def _handle_request( # type: ignore[override] + self, request: Request, room_id: str + ) -> Tuple[int, JsonDict]: await self.store.clean_room_for_join(room_id) return 200, {} @@ -283,17 +300,19 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): self.store = hs.get_datastore() @staticmethod - async def _serialize_payload(room_id, room_version): + async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDict: # type: ignore[override] return {"room_version": room_version.identifier} - async def _handle_request(self, request, room_id): + async def _handle_request( # type: ignore[override] + self, request: Request, room_id: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) room_version = KNOWN_ROOM_VERSIONS[content["room_version"]] await self.store.maybe_store_room_on_outlier_membership(room_id, room_version) return 200, {} -def register_servlets(hs: "HomeServer", http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationFederationSendEventsRestServlet(hs).register(http_server) ReplicationFederationSendEduRestServlet(hs).register(http_server) ReplicationGetQueryRestServlet(hs).register(http_server) |