summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_server.py70
1 files changed, 58 insertions, 12 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index d7ce333822..8eddb3bf2c 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import Dict
 
 import six
 from six import iteritems
@@ -22,6 +23,7 @@ from six import iteritems
 from canonicaljson import json
 from prometheus_client import Counter
 
+from twisted.internet import defer
 from twisted.internet.abstract import isIPAddress
 from twisted.python import failure
 
@@ -41,7 +43,11 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
 from synapse.http.endpoint import parse_server_name
-from synapse.logging.context import nested_logging_context
+from synapse.logging.context import (
+    make_deferred_yieldable,
+    nested_logging_context,
+    run_in_background,
+)
 from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
 from synapse.logging.utils import log_function
 from synapse.replication.http.federation import (
@@ -49,7 +55,7 @@ from synapse.replication.http.federation import (
     ReplicationGetQueryRestServlet,
 )
 from synapse.types import get_domain_from_id
-from synapse.util import glob_to_regex
+from synapse.util import glob_to_regex, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
 
@@ -160,6 +166,43 @@ class FederationServer(FederationBase):
             )
             return 400, response
 
+        # We process PDUs and EDUs in parallel. This is important as we don't
+        # want to block things like to device messages from reaching clients
+        # behind the potentially expensive handling of PDUs.
+        pdu_results, _ = await make_deferred_yieldable(
+            defer.gatherResults(
+                [
+                    run_in_background(
+                        self._handle_pdus_in_txn, origin, transaction, request_time
+                    ),
+                    run_in_background(self._handle_edus_in_txn, origin, transaction),
+                ],
+                consumeErrors=True,
+            ).addErrback(unwrapFirstError)
+        )
+
+        response = {"pdus": pdu_results}
+
+        logger.debug("Returning: %s", str(response))
+
+        await self.transaction_actions.set_response(origin, transaction, 200, response)
+        return 200, response
+
+    async def _handle_pdus_in_txn(
+        self, origin: str, transaction: Transaction, request_time: int
+    ) -> Dict[str, dict]:
+        """Process the PDUs in a received transaction.
+
+        Args:
+            origin: the server making the request
+            transaction: incoming transaction
+            request_time: timestamp that the HTTP request arrived at
+
+        Returns:
+            A map from event ID of a processed PDU to any errors we should
+            report back to the sending server.
+        """
+
         received_pdus_counter.inc(len(transaction.pdus))
 
         origin_host, _ = parse_server_name(origin)
@@ -250,20 +293,23 @@ class FederationServer(FederationBase):
             process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
         )
 
-        if hasattr(transaction, "edus"):
-            for edu in (Edu(**x) for x in transaction.edus):
-                await self.received_edu(origin, edu.edu_type, edu.content)
+        return pdu_results
 
-        response = {"pdus": pdu_results}
+    async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
+        """Process the EDUs in a received transaction.
+        """
 
-        logger.debug("Returning: %s", str(response))
+        async def _process_edu(edu_dict):
+            received_edus_counter.inc()
 
-        await self.transaction_actions.set_response(origin, transaction, 200, response)
-        return 200, response
+            edu = Edu(**edu_dict)
+            await self.registry.on_edu(edu.edu_type, origin, edu.content)
 
-    async def received_edu(self, origin, edu_type, content):
-        received_edus_counter.inc()
-        await self.registry.on_edu(edu_type, origin, content)
+        await concurrently_execute(
+            _process_edu,
+            getattr(transaction, "edus", []),
+            TRANSACTION_CONCURRENCY_LIMIT,
+        )
 
     async def on_context_state_request(self, origin, room_id, event_id):
         origin_host, _ = parse_server_name(origin)