diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index d7ce333822..8eddb3bf2c 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Dict
import six
from six import iteritems
@@ -22,6 +23,7 @@ from six import iteritems
from canonicaljson import json
from prometheus_client import Counter
+from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
@@ -41,7 +43,11 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name
-from synapse.logging.context import nested_logging_context
+from synapse.logging.context import (
+ make_deferred_yieldable,
+ nested_logging_context,
+ run_in_background,
+)
from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
from synapse.logging.utils import log_function
from synapse.replication.http.federation import (
@@ -49,7 +55,7 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet,
)
from synapse.types import get_domain_from_id
-from synapse.util import glob_to_regex
+from synapse.util import glob_to_regex, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@@ -160,6 +166,43 @@ class FederationServer(FederationBase):
)
return 400, response
+ # We process PDUs and EDUs in parallel. This is important as we don't
+ # want to block things like to device messages from reaching clients
+ # behind the potentially expensive handling of PDUs.
+ pdu_results, _ = await make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(
+ self._handle_pdus_in_txn, origin, transaction, request_time
+ ),
+ run_in_background(self._handle_edus_in_txn, origin, transaction),
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
+
+ response = {"pdus": pdu_results}
+
+ logger.debug("Returning: %s", str(response))
+
+ await self.transaction_actions.set_response(origin, transaction, 200, response)
+ return 200, response
+
+ async def _handle_pdus_in_txn(
+ self, origin: str, transaction: Transaction, request_time: int
+ ) -> Dict[str, dict]:
+ """Process the PDUs in a received transaction.
+
+ Args:
+ origin: the server making the request
+ transaction: incoming transaction
+ request_time: timestamp that the HTTP request arrived at
+
+ Returns:
+ A map from event ID of a processed PDU to any errors we should
+ report back to the sending server.
+ """
+
received_pdus_counter.inc(len(transaction.pdus))
origin_host, _ = parse_server_name(origin)
@@ -250,20 +293,23 @@ class FederationServer(FederationBase):
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
)
- if hasattr(transaction, "edus"):
- for edu in (Edu(**x) for x in transaction.edus):
- await self.received_edu(origin, edu.edu_type, edu.content)
+ return pdu_results
- response = {"pdus": pdu_results}
+ async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
+ """Process the EDUs in a received transaction.
+ """
- logger.debug("Returning: %s", str(response))
+ async def _process_edu(edu_dict):
+ received_edus_counter.inc()
- await self.transaction_actions.set_response(origin, transaction, 200, response)
- return 200, response
+ edu = Edu(**edu_dict)
+ await self.registry.on_edu(edu.edu_type, origin, edu.content)
- async def received_edu(self, origin, edu_type, content):
- received_edus_counter.inc()
- await self.registry.on_edu(edu_type, origin, content)
+ await concurrently_execute(
+ _process_edu,
+ getattr(transaction, "edus", []),
+ TRANSACTION_CONCURRENCY_LIMIT,
+ )
async def on_context_state_request(self, origin, room_id, event_id):
origin_host, _ = parse_server_name(origin)
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index ced4925a98..174f6e42be 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -259,7 +259,9 @@ class FederationRemoteSendQueue(object):
def federation_ack(self, token):
self._clear_queue_before_pos(token)
- def get_replication_rows(self, from_token, to_token, limit, federation_ack=None):
+ async def get_replication_rows(
+ self, from_token, to_token, limit, federation_ack=None
+ ):
"""Get rows to be sent over federation between the two tokens
Args:
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 4ebb0e8bc0..36c83c3027 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -21,6 +21,7 @@ from prometheus_client import Counter
from twisted.internet import defer
+import synapse
import synapse.metrics
from synapse.federation.sender.per_destination_queue import PerDestinationQueue
from synapse.federation.sender.transaction_manager import TransactionManager
@@ -54,7 +55,7 @@ sent_pdus_destination_dist_total = Counter(
class FederationSender(object):
- def __init__(self, hs):
+ def __init__(self, hs: "synapse.server.HomeServer"):
self.hs = hs
self.server_name = hs.hostname
@@ -482,7 +483,20 @@ class FederationSender(object):
def send_device_messages(self, destination):
if destination == self.server_name:
- logger.info("Not sending device update to ourselves")
+ logger.warning("Not sending device update to ourselves")
+ return
+
+ self._get_per_destination_queue(destination).attempt_new_transaction()
+
+ def wake_destination(self, destination: str):
+ """Called when we want to retry sending transactions to a remote.
+
+ This is mainly useful if the remote server has been down and we think it
+ might have come back.
+ """
+
+ if destination == self.server_name:
+ logger.warning("Not waking up ourselves")
return
self._get_per_destination_queue(destination).attempt_new_transaction()
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index a5b36b1827..5012aaea35 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -31,6 +31,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
+from synapse.types import StateMap
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
# This is defined in the Matrix spec and enforced by the receiver.
@@ -77,7 +78,7 @@ class PerDestinationQueue(object):
# Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
# based on their key (e.g. typing events by room_id)
# Map of (edu_type, key) -> Edu
- self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu]
+ self._pending_edus_keyed = {} # type: StateMap[Edu]
# Map of user_id -> UserPresenceState of pending presence to be sent to this
# destination
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index b4cbf23394..d8cf9ed299 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -44,6 +44,7 @@ from synapse.logging.opentracing import (
tags,
whitelisted_homeserver,
)
+from synapse.server import HomeServer
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
@@ -101,12 +102,17 @@ class NoAuthenticationError(AuthenticationError):
class Authenticator(object):
- def __init__(self, hs):
+ def __init__(self, hs: HomeServer):
self._clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+ self.notifer = hs.get_notifier()
+
+ self.replication_client = None
+ if hs.config.worker.worker_app:
+ self.replication_client = hs.get_tcp_replication()
# A method just so we can pass 'self' as the authenticator to the Servlets
async def authenticate_request(self, request, content):
@@ -166,6 +172,17 @@ class Authenticator(object):
try:
logger.info("Marking origin %r as up", origin)
await self.store.set_destination_retry_timings(origin, None, 0, 0)
+
+ # Inform the relevant places that the remote server is back up.
+ self.notifer.notify_remote_server_up(origin)
+ if self.replication_client:
+ # If we're on a worker we try and inform master about this. The
+ # replication client doesn't hook into the notifier to avoid
+ # infinite loops where we send a `REMOTE_SERVER_UP` command to
+ # master, which then echoes it back to us which in turn pokes
+ # the notifier.
+ self.replication_client.send_remote_server_up(origin)
+
except Exception:
logger.exception("Error resetting retry timings on %s", origin)
|