summary refs log tree commit diff
path: root/synapse/federation/federation_server.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_server.py')
-rw-r--r--synapse/federation/federation_server.py152
1 files changed, 85 insertions, 67 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py

index f8e368f81b..98caf2a1a4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py
@@ -112,10 +112,11 @@ class FederationServer(FederationBase): # with FederationHandlerRegistry. hs.get_directory_handler() - self._federation_ratelimiter = hs.get_federation_ratelimiter() - self._server_linearizer = Linearizer("fed_server") - self._transaction_linearizer = Linearizer("fed_txn_handler") + + # origins that we are currently processing a transaction from. + # a dict from origin to txn id. + self._active_transactions = {} # type: Dict[str, str] # We cache results for transaction with the same ID self._transaction_resp_cache = ResponseCache( @@ -169,6 +170,33 @@ class FederationServer(FederationBase): logger.debug("[%s] Got transaction", transaction_id) + # Reject malformed transactions early: reject if too many PDUs/EDUs + if len(transaction.pdus) > 50 or ( # type: ignore + hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore + ): + logger.info("Transaction PDU or EDU count too large. Returning 400") + return 400, {} + + # we only process one transaction from each origin at a time. We need to do + # this check here, rather than in _on_incoming_transaction_inner so that we + # don't cache the rejection in _transaction_resp_cache (so that if the txn + # arrives again later, we can process it). + current_transaction = self._active_transactions.get(origin) + if current_transaction and current_transaction != transaction_id: + logger.warning( + "Received another txn %s from %s while still processing %s", + transaction_id, + origin, + current_transaction, + ) + return 429, { + "errcode": Codes.UNKNOWN, + "error": "Too many concurrent transactions", + } + + # CRITICAL SECTION: we must now not await until we populate _active_transactions + # in _on_incoming_transaction_inner. + # We wrap in a ResponseCache so that we de-duplicate retried # transactions. return await self._transaction_resp_cache.wrap( @@ -182,26 +210,18 @@ class FederationServer(FederationBase): async def _on_incoming_transaction_inner( self, origin: str, transaction: Transaction, request_time: int ) -> Tuple[int, Dict[str, Any]]: - # Use a linearizer to ensure that transactions from a remote are - # processed in order. - with await self._transaction_linearizer.queue(origin): - # We rate limit here *after* we've queued up the incoming requests, - # so that we don't fill up the ratelimiter with blocked requests. - # - # This is important as the ratelimiter allows N concurrent requests - # at a time, and only starts ratelimiting if there are more requests - # than that being processed at a time. If we queued up requests in - # the linearizer/response cache *after* the ratelimiting then those - # queued up requests would count as part of the allowed limit of N - # concurrent requests. - with self._federation_ratelimiter.ratelimit(origin) as d: - await d - - result = await self._handle_incoming_transaction( - origin, transaction, request_time - ) + # CRITICAL SECTION: the first thing we must do (before awaiting) is + # add an entry to _active_transactions. + assert origin not in self._active_transactions + self._active_transactions[origin] = transaction.transaction_id # type: ignore - return result + try: + result = await self._handle_incoming_transaction( + origin, transaction, request_time + ) + return result + finally: + del self._active_transactions[origin] async def _handle_incoming_transaction( self, origin: str, transaction: Transaction, request_time: int @@ -227,19 +247,6 @@ class FederationServer(FederationBase): logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore - # Reject if PDU count > 50 or EDU count > 100 - if len(transaction.pdus) > 50 or ( # type: ignore - hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore - ): - - logger.info("Transaction PDU or EDU count too large. Returning 400") - - response = {} - await self.transaction_actions.set_response( - origin, transaction, 400, response - ) - return 400, response - # We process PDUs and EDUs in parallel. This is important as we don't # want to block things like to device messages from reaching clients # behind the potentially expensive handling of PDUs. @@ -335,34 +342,41 @@ class FederationServer(FederationBase): # impose a limit to avoid going too crazy with ram/cpu. async def process_pdus_for_room(room_id: str): - logger.debug("Processing PDUs for %s", room_id) - try: - await self.check_server_matches_acl(origin_host, room_id) - except AuthError as e: - logger.warning("Ignoring PDUs for room %s from banned server", room_id) - for pdu in pdus_by_room[room_id]: - event_id = pdu.event_id - pdu_results[event_id] = e.error_dict() - return + with nested_logging_context(room_id): + logger.debug("Processing PDUs for %s", room_id) - for pdu in pdus_by_room[room_id]: - event_id = pdu.event_id - with pdu_process_time.time(): - with nested_logging_context(event_id): - try: - await self._handle_received_pdu(origin, pdu) - pdu_results[event_id] = {} - except FederationError as e: - logger.warning("Error handling PDU %s: %s", event_id, e) - pdu_results[event_id] = {"error": str(e)} - except Exception as e: - f = failure.Failure() - pdu_results[event_id] = {"error": str(e)} - logger.error( - "Failed to handle PDU %s", - event_id, - exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore - ) + try: + await self.check_server_matches_acl(origin_host, room_id) + except AuthError as e: + logger.warning( + "Ignoring PDUs for room %s from banned server", room_id + ) + for pdu in pdus_by_room[room_id]: + event_id = pdu.event_id + pdu_results[event_id] = e.error_dict() + return + + for pdu in pdus_by_room[room_id]: + pdu_results[pdu.event_id] = await process_pdu(pdu) + + async def process_pdu(pdu: EventBase) -> JsonDict: + event_id = pdu.event_id + with pdu_process_time.time(): + with nested_logging_context(event_id): + try: + await self._handle_received_pdu(origin, pdu) + return {} + except FederationError as e: + logger.warning("Error handling PDU %s: %s", event_id, e) + return {"error": str(e)} + except Exception as e: + f = failure.Failure() + logger.error( + "Failed to handle PDU %s", + event_id, + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + ) + return {"error": str(e)} await concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT @@ -447,7 +461,7 @@ class FederationServer(FederationBase): async def _on_state_ids_request_compute(self, room_id, event_id): state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) - auth_chain_ids = await self.store.get_auth_chain_ids(state_ids) + auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids) return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} async def _on_context_state_request_compute( @@ -460,7 +474,9 @@ class FederationServer(FederationBase): else: pdus = (await self.state.get_current_state(room_id)).values() - auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) + auth_chain = await self.store.get_auth_chain( + room_id, [pdu.event_id for pdu in pdus] + ) return { "pdus": [pdu.get_pdu_json() for pdu in pdus], @@ -864,7 +880,9 @@ class FederationHandlerRegistry: self.edu_handlers = ( {} ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]] - self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]] + self.query_handlers = ( + {} + ) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]] # Map from type to instance names that we should route EDU handling to. # We randomly choose one instance from the list to route to for each new @@ -898,7 +916,7 @@ class FederationHandlerRegistry: self.edu_handlers[edu_type] = handler def register_query_handler( - self, query_type: str, handler: Callable[[dict], defer.Deferred] + self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]] ): """Sets the handler callable that will be used to handle an incoming federation query of the given type. @@ -975,7 +993,7 @@ class FederationHandlerRegistry: # Oh well, let's just log and move on. logger.warning("No handler registered for EDU type %s", edu_type) - async def on_query(self, query_type: str, args: dict): + async def on_query(self, query_type: str, args: dict) -> JsonDict: handler = self.query_handlers.get(query_type) if handler: return await handler(args)