diff options
Diffstat (limited to 'synapse/federation/federation_server.py')
-rw-r--r-- | synapse/federation/federation_server.py | 45 |
1 files changed, 24 insertions, 21 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 24329dd0e3..4b6ab470d0 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -22,7 +22,6 @@ from typing import ( Callable, Dict, List, - Match, Optional, Tuple, Union, @@ -50,6 +49,7 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction from synapse.http.endpoint import parse_server_name +from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, nested_logging_context, @@ -100,10 +100,15 @@ class FederationServer(FederationBase): super().__init__(hs) self.auth = hs.get_auth() - self.handler = hs.get_handlers().federation_handler + self.handler = hs.get_federation_handler() self.state = hs.get_state_handler() self.device_handler = hs.get_device_handler() + + # Ensure the following handlers are loaded since they register callbacks + # with FederationHandlerRegistry. + hs.get_directory_handler() + self._federation_ratelimiter = hs.get_federation_ratelimiter() self._server_linearizer = Linearizer("fed_server") @@ -112,7 +117,7 @@ class FederationServer(FederationBase): # We cache results for transaction with the same ID self._transaction_resp_cache = ResponseCache( hs, "fed_txn_handler", timeout_ms=30000 - ) + ) # type: ResponseCache[Tuple[str, str]] self.transaction_actions = TransactionActions(self.store) @@ -120,10 +125,12 @@ class FederationServer(FederationBase): # We cache responses to state queries, as they take a while and often # come in waves. - self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) + self._state_resp_cache = ResponseCache( + hs, "state_resp", timeout_ms=30000 + ) # type: ResponseCache[Tuple[str, str]] self._state_ids_resp_cache = ResponseCache( hs, "state_ids_resp", timeout_ms=30000 - ) + ) # type: ResponseCache[Tuple[str, str]] self._federation_metrics_domains = ( hs.get_config().federation.federation_metrics_domains @@ -385,7 +392,7 @@ class FederationServer(FederationBase): TRANSACTION_CONCURRENCY_LIMIT, ) - async def on_context_state_request( + async def on_room_state_request( self, origin: str, room_id: str, event_id: str ) -> Tuple[int, Dict[str, Any]]: origin_host, _ = parse_server_name(origin) @@ -508,11 +515,12 @@ class FederationServer(FederationBase): return {"event": ret_pdu.get_pdu_json(time_now)} async def on_send_join_request( - self, origin: str, content: JsonDict, room_id: str + self, origin: str, content: JsonDict ) -> Dict[str, Any]: logger.debug("on_send_join_request: content: %s", content) - room_version = await self.store.get_room_version(room_id) + assert_params_in_dict(content, ["room_id"]) + room_version = await self.store.get_room_version(content["room_id"]) pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) @@ -541,12 +549,11 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - async def on_send_leave_request( - self, origin: str, content: JsonDict, room_id: str - ) -> dict: + async def on_send_leave_request(self, origin: str, content: JsonDict) -> dict: logger.debug("on_send_leave_request: content: %s", content) - room_version = await self.store.get_room_version(room_id) + assert_params_in_dict(content, ["room_id"]) + room_version = await self.store.get_room_version(content["room_id"]) pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) @@ -742,12 +749,8 @@ class FederationServer(FederationBase): ) return ret - async def on_exchange_third_party_invite_request( - self, room_id: str, event_dict: Dict - ): - ret = await self.handler.on_exchange_third_party_invite_request( - room_id, event_dict - ) + async def on_exchange_third_party_invite_request(self, event_dict: Dict): + ret = await self.handler.on_exchange_third_party_invite_request(event_dict) return ret async def check_server_matches_acl(self, server_name: str, room_id: str): @@ -825,14 +828,14 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: return False -def _acl_entry_matches(server_name: str, acl_entry: str) -> Match: +def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool: if not isinstance(acl_entry, str): logger.warning( "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) ) return False regex = glob_to_regex(acl_entry) - return regex.match(server_name) + return bool(regex.match(server_name)) class FederationHandlerRegistry: @@ -862,7 +865,7 @@ class FederationHandlerRegistry: self._edu_type_to_instance = {} # type: Dict[str, str] def register_edu_handler( - self, edu_type: str, handler: Callable[[str, dict], Awaitable[None]] + self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]] ): """Sets the handler callable that will be used to handle an incoming federation EDU of the given type. |