diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index d7ce333822..8eddb3bf2c 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Dict
import six
from six import iteritems
@@ -22,6 +23,7 @@ from six import iteritems
from canonicaljson import json
from prometheus_client import Counter
+from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
@@ -41,7 +43,11 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name
-from synapse.logging.context import nested_logging_context
+from synapse.logging.context import (
+ make_deferred_yieldable,
+ nested_logging_context,
+ run_in_background,
+)
from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
from synapse.logging.utils import log_function
from synapse.replication.http.federation import (
@@ -49,7 +55,7 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet,
)
from synapse.types import get_domain_from_id
-from synapse.util import glob_to_regex
+from synapse.util import glob_to_regex, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@@ -160,6 +166,43 @@ class FederationServer(FederationBase):
)
return 400, response
+ # We process PDUs and EDUs in parallel. This is important as we don't
+ # want to block things like to device messages from reaching clients
+ # behind the potentially expensive handling of PDUs.
+ pdu_results, _ = await make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(
+ self._handle_pdus_in_txn, origin, transaction, request_time
+ ),
+ run_in_background(self._handle_edus_in_txn, origin, transaction),
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
+
+ response = {"pdus": pdu_results}
+
+ logger.debug("Returning: %s", str(response))
+
+ await self.transaction_actions.set_response(origin, transaction, 200, response)
+ return 200, response
+
+ async def _handle_pdus_in_txn(
+ self, origin: str, transaction: Transaction, request_time: int
+ ) -> Dict[str, dict]:
+ """Process the PDUs in a received transaction.
+
+ Args:
+ origin: the server making the request
+ transaction: incoming transaction
+ request_time: timestamp that the HTTP request arrived at
+
+ Returns:
+ A map from event ID of a processed PDU to any errors we should
+ report back to the sending server.
+ """
+
received_pdus_counter.inc(len(transaction.pdus))
origin_host, _ = parse_server_name(origin)
@@ -250,20 +293,23 @@ class FederationServer(FederationBase):
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
)
- if hasattr(transaction, "edus"):
- for edu in (Edu(**x) for x in transaction.edus):
- await self.received_edu(origin, edu.edu_type, edu.content)
+ return pdu_results
- response = {"pdus": pdu_results}
+ async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
+ """Process the EDUs in a received transaction.
+ """
- logger.debug("Returning: %s", str(response))
+ async def _process_edu(edu_dict):
+ received_edus_counter.inc()
- await self.transaction_actions.set_response(origin, transaction, 200, response)
- return 200, response
+ edu = Edu(**edu_dict)
+ await self.registry.on_edu(edu.edu_type, origin, edu.content)
- async def received_edu(self, origin, edu_type, content):
- received_edus_counter.inc()
- await self.registry.on_edu(edu_type, origin, content)
+ await concurrently_execute(
+ _process_edu,
+ getattr(transaction, "edus", []),
+ TRANSACTION_CONCURRENCY_LIMIT,
+ )
async def on_context_state_request(self, origin, room_id, event_id):
origin_host, _ = parse_server_name(origin)
|