summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/federation/federation_server.py77
1 files changed, 42 insertions, 35 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 89fe6def4b..db6e49dbca 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -112,10 +112,11 @@ class FederationServer(FederationBase):
         # with FederationHandlerRegistry.
         hs.get_directory_handler()
 
-        self._federation_ratelimiter = hs.get_federation_ratelimiter()
-
         self._server_linearizer = Linearizer("fed_server")
-        self._transaction_linearizer = Linearizer("fed_txn_handler")
+
+        # origins that we are currently processing a transaction from.
+        # a dict from origin to txn id.
+        self._active_transactions = {}  # type: Dict[str, str]
 
         # We cache results for transaction with the same ID
         self._transaction_resp_cache = ResponseCache(
@@ -169,6 +170,33 @@ class FederationServer(FederationBase):
 
         logger.debug("[%s] Got transaction", transaction_id)
 
+        # Reject malformed transactions early: reject if too many PDUs/EDUs
+        if len(transaction.pdus) > 50 or (  # type: ignore
+            hasattr(transaction, "edus") and len(transaction.edus) > 100  # type: ignore
+        ):
+            logger.info("Transaction PDU or EDU count too large. Returning 400")
+            return 400, {}
+
+        # we only process one transaction from each origin at a time. We need to do
+        # this check here, rather than in _on_incoming_transaction_inner so that we
+        # don't cache the rejection in _transaction_resp_cache (so that if the txn
+        # arrives again later, we can process it).
+        current_transaction = self._active_transactions.get(origin)
+        if current_transaction and current_transaction != transaction_id:
+            logger.warning(
+                "Received another txn %s from %s while still processing %s",
+                transaction_id,
+                origin,
+                current_transaction,
+            )
+            return 429, {
+                "errcode": Codes.UNKNOWN,
+                "error": "Too many concurrent transactions",
+            }
+
+        # CRITICAL SECTION: we must now not await until we populate _active_transactions
+        # in _on_incoming_transaction_inner.
+
         # We wrap in a ResponseCache so that we de-duplicate retried
         # transactions.
         return await self._transaction_resp_cache.wrap(
@@ -182,26 +210,18 @@ class FederationServer(FederationBase):
     async def _on_incoming_transaction_inner(
         self, origin: str, transaction: Transaction, request_time: int
     ) -> Tuple[int, Dict[str, Any]]:
-        # Use a linearizer to ensure that transactions from a remote are
-        # processed in order.
-        with await self._transaction_linearizer.queue(origin):
-            # We rate limit here *after* we've queued up the incoming requests,
-            # so that we don't fill up the ratelimiter with blocked requests.
-            #
-            # This is important as the ratelimiter allows N concurrent requests
-            # at a time, and only starts ratelimiting if there are more requests
-            # than that being processed at a time. If we queued up requests in
-            # the linearizer/response cache *after* the ratelimiting then those
-            # queued up requests would count as part of the allowed limit of N
-            # concurrent requests.
-            with self._federation_ratelimiter.ratelimit(origin) as d:
-                await d
-
-                result = await self._handle_incoming_transaction(
-                    origin, transaction, request_time
-                )
+        # CRITICAL SECTION: the first thing we must do (before awaiting) is
+        # add an entry to _active_transactions.
+        assert origin not in self._active_transactions
+        self._active_transactions[origin] = transaction.transaction_id  # type: ignore
 
-        return result
+        try:
+            result = await self._handle_incoming_transaction(
+                origin, transaction, request_time
+            )
+            return result
+        finally:
+            del self._active_transactions[origin]
 
     async def _handle_incoming_transaction(
         self, origin: str, transaction: Transaction, request_time: int
@@ -227,19 +247,6 @@ class FederationServer(FederationBase):
 
         logger.debug("[%s] Transaction is new", transaction.transaction_id)  # type: ignore
 
-        # Reject if PDU count > 50 or EDU count > 100
-        if len(transaction.pdus) > 50 or (  # type: ignore
-            hasattr(transaction, "edus") and len(transaction.edus) > 100  # type: ignore
-        ):
-
-            logger.info("Transaction PDU or EDU count too large. Returning 400")
-
-            response = {}
-            await self.transaction_actions.set_response(
-                origin, transaction, 400, response
-            )
-            return 400, response
-
         # We process PDUs and EDUs in parallel. This is important as we don't
         # want to block things like to device messages from reaching clients
         # behind the potentially expensive handling of PDUs.