diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index f8e368f81b..98caf2a1a4 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -112,10 +112,11 @@ class FederationServer(FederationBase):
# with FederationHandlerRegistry.
hs.get_directory_handler()
- self._federation_ratelimiter = hs.get_federation_ratelimiter()
-
self._server_linearizer = Linearizer("fed_server")
- self._transaction_linearizer = Linearizer("fed_txn_handler")
+
+ # origins that we are currently processing a transaction from.
+ # a dict from origin to txn id.
+ self._active_transactions = {} # type: Dict[str, str]
# We cache results for transaction with the same ID
self._transaction_resp_cache = ResponseCache(
@@ -169,6 +170,33 @@ class FederationServer(FederationBase):
logger.debug("[%s] Got transaction", transaction_id)
+ # Reject malformed transactions early: reject if too many PDUs/EDUs
+ if len(transaction.pdus) > 50 or ( # type: ignore
+ hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
+ ):
+ logger.info("Transaction PDU or EDU count too large. Returning 400")
+ return 400, {}
+
+ # we only process one transaction from each origin at a time. We need to do
+ # this check here, rather than in _on_incoming_transaction_inner so that we
+ # don't cache the rejection in _transaction_resp_cache (so that if the txn
+ # arrives again later, we can process it).
+ current_transaction = self._active_transactions.get(origin)
+ if current_transaction and current_transaction != transaction_id:
+ logger.warning(
+ "Received another txn %s from %s while still processing %s",
+ transaction_id,
+ origin,
+ current_transaction,
+ )
+ return 429, {
+ "errcode": Codes.UNKNOWN,
+ "error": "Too many concurrent transactions",
+ }
+
+ # CRITICAL SECTION: we must now not await until we populate _active_transactions
+ # in _on_incoming_transaction_inner.
+
# We wrap in a ResponseCache so that we de-duplicate retried
# transactions.
return await self._transaction_resp_cache.wrap(
@@ -182,26 +210,18 @@ class FederationServer(FederationBase):
async def _on_incoming_transaction_inner(
self, origin: str, transaction: Transaction, request_time: int
) -> Tuple[int, Dict[str, Any]]:
- # Use a linearizer to ensure that transactions from a remote are
- # processed in order.
- with await self._transaction_linearizer.queue(origin):
- # We rate limit here *after* we've queued up the incoming requests,
- # so that we don't fill up the ratelimiter with blocked requests.
- #
- # This is important as the ratelimiter allows N concurrent requests
- # at a time, and only starts ratelimiting if there are more requests
- # than that being processed at a time. If we queued up requests in
- # the linearizer/response cache *after* the ratelimiting then those
- # queued up requests would count as part of the allowed limit of N
- # concurrent requests.
- with self._federation_ratelimiter.ratelimit(origin) as d:
- await d
-
- result = await self._handle_incoming_transaction(
- origin, transaction, request_time
- )
+ # CRITICAL SECTION: the first thing we must do (before awaiting) is
+ # add an entry to _active_transactions.
+ assert origin not in self._active_transactions
+ self._active_transactions[origin] = transaction.transaction_id # type: ignore
- return result
+ try:
+ result = await self._handle_incoming_transaction(
+ origin, transaction, request_time
+ )
+ return result
+ finally:
+ del self._active_transactions[origin]
async def _handle_incoming_transaction(
self, origin: str, transaction: Transaction, request_time: int
@@ -227,19 +247,6 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore
- # Reject if PDU count > 50 or EDU count > 100
- if len(transaction.pdus) > 50 or ( # type: ignore
- hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
- ):
-
- logger.info("Transaction PDU or EDU count too large. Returning 400")
-
- response = {}
- await self.transaction_actions.set_response(
- origin, transaction, 400, response
- )
- return 400, response
-
# We process PDUs and EDUs in parallel. This is important as we don't
# want to block things like to device messages from reaching clients
# behind the potentially expensive handling of PDUs.
@@ -335,34 +342,41 @@ class FederationServer(FederationBase):
# impose a limit to avoid going too crazy with ram/cpu.
async def process_pdus_for_room(room_id: str):
- logger.debug("Processing PDUs for %s", room_id)
- try:
- await self.check_server_matches_acl(origin_host, room_id)
- except AuthError as e:
- logger.warning("Ignoring PDUs for room %s from banned server", room_id)
- for pdu in pdus_by_room[room_id]:
- event_id = pdu.event_id
- pdu_results[event_id] = e.error_dict()
- return
+ with nested_logging_context(room_id):
+ logger.debug("Processing PDUs for %s", room_id)
- for pdu in pdus_by_room[room_id]:
- event_id = pdu.event_id
- with pdu_process_time.time():
- with nested_logging_context(event_id):
- try:
- await self._handle_received_pdu(origin, pdu)
- pdu_results[event_id] = {}
- except FederationError as e:
- logger.warning("Error handling PDU %s: %s", event_id, e)
- pdu_results[event_id] = {"error": str(e)}
- except Exception as e:
- f = failure.Failure()
- pdu_results[event_id] = {"error": str(e)}
- logger.error(
- "Failed to handle PDU %s",
- event_id,
- exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
- )
+ try:
+ await self.check_server_matches_acl(origin_host, room_id)
+ except AuthError as e:
+ logger.warning(
+ "Ignoring PDUs for room %s from banned server", room_id
+ )
+ for pdu in pdus_by_room[room_id]:
+ event_id = pdu.event_id
+ pdu_results[event_id] = e.error_dict()
+ return
+
+ for pdu in pdus_by_room[room_id]:
+ pdu_results[pdu.event_id] = await process_pdu(pdu)
+
+ async def process_pdu(pdu: EventBase) -> JsonDict:
+ event_id = pdu.event_id
+ with pdu_process_time.time():
+ with nested_logging_context(event_id):
+ try:
+ await self._handle_received_pdu(origin, pdu)
+ return {}
+ except FederationError as e:
+ logger.warning("Error handling PDU %s: %s", event_id, e)
+ return {"error": str(e)}
+ except Exception as e:
+ f = failure.Failure()
+ logger.error(
+ "Failed to handle PDU %s",
+ event_id,
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
+ )
+ return {"error": str(e)}
await concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
@@ -447,7 +461,7 @@ class FederationServer(FederationBase):
async def _on_state_ids_request_compute(self, room_id, event_id):
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
- auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
+ auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
async def _on_context_state_request_compute(
@@ -460,7 +474,9 @@ class FederationServer(FederationBase):
else:
pdus = (await self.state.get_current_state(room_id)).values()
- auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
+ auth_chain = await self.store.get_auth_chain(
+ room_id, [pdu.event_id for pdu in pdus]
+ )
return {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
@@ -864,7 +880,9 @@ class FederationHandlerRegistry:
self.edu_handlers = (
{}
) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
- self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
+ self.query_handlers = (
+ {}
+ ) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]
# Map from type to instance names that we should route EDU handling to.
# We randomly choose one instance from the list to route to for each new
@@ -898,7 +916,7 @@ class FederationHandlerRegistry:
self.edu_handlers[edu_type] = handler
def register_query_handler(
- self, query_type: str, handler: Callable[[dict], defer.Deferred]
+ self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
):
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
@@ -975,7 +993,7 @@ class FederationHandlerRegistry:
# Oh well, let's just log and move on.
logger.warning("No handler registered for EDU type %s", edu_type)
- async def on_query(self, query_type: str, args: dict):
+ async def on_query(self, query_type: str, args: dict) -> JsonDict:
handler = self.query_handlers.get(query_type)
if handler:
return await handler(args)
|