summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_client.py23
-rw-r--r--synapse/federation/federation_server.py3
-rw-r--r--synapse/federation/persistence.py3
-rw-r--r--synapse/federation/sender/per_destination_queue.py12
-rw-r--r--synapse/federation/sender/transaction_manager.py12
-rw-r--r--synapse/federation/transport/client.py48
-rw-r--r--synapse/federation/transport/server/__init__.py16
-rw-r--r--synapse/federation/transport/server/_base.py81
-rw-r--r--synapse/federation/transport/server/federation.py39
-rw-r--r--synapse/federation/transport/server/groups_local.py8
-rw-r--r--synapse/federation/transport/server/groups_server.py8
11 files changed, 135 insertions, 118 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 6ea4edfc71..74f17aa4da 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -56,7 +56,6 @@ from synapse.api.room_versions import (
 from synapse.events import EventBase, builder
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.federation.transport.client import SendJoinResponse
-from synapse.logging.utils import log_function
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -119,7 +118,8 @@ class FederationClient(FederationBase):
         # It is a map of (room ID, suggested-only) -> the response of
         # get_room_hierarchy.
         self._get_room_hierarchy_cache: ExpiringCache[
-            Tuple[str, bool], Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]
+            Tuple[str, bool],
+            Tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]],
         ] = ExpiringCache(
             cache_name="get_room_hierarchy_cache",
             clock=self._clock,
@@ -144,7 +144,6 @@ class FederationClient(FederationBase):
             if destination_dict:
                 self.pdu_destination_tried[event_id] = destination_dict
 
-    @log_function
     async def make_query(
         self,
         destination: str,
@@ -178,7 +177,6 @@ class FederationClient(FederationBase):
             ignore_backoff=ignore_backoff,
         )
 
-    @log_function
     async def query_client_keys(
         self, destination: str, content: JsonDict, timeout: int
     ) -> JsonDict:
@@ -196,7 +194,6 @@ class FederationClient(FederationBase):
             destination, content, timeout
         )
 
-    @log_function
     async def query_user_devices(
         self, destination: str, user_id: str, timeout: int = 30000
     ) -> JsonDict:
@@ -208,7 +205,6 @@ class FederationClient(FederationBase):
             destination, user_id, timeout
         )
 
-    @log_function
     async def claim_client_keys(
         self, destination: str, content: JsonDict, timeout: int
     ) -> JsonDict:
@@ -1338,7 +1334,7 @@ class FederationClient(FederationBase):
         destinations: Iterable[str],
         room_id: str,
         suggested_only: bool,
-    ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
+    ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]]:
         """
         Call other servers to get a hierarchy of the given room.
 
@@ -1353,7 +1349,8 @@ class FederationClient(FederationBase):
 
         Returns:
             A tuple of:
-                The room as a JSON dictionary.
+                The room as a JSON dictionary, without a "children_state" key.
+                A list of `m.space.child` state events.
                 A list of children rooms, as JSON dictionaries.
                 A list of inaccessible children room IDs.
 
@@ -1368,7 +1365,7 @@ class FederationClient(FederationBase):
 
         async def send_request(
             destination: str,
-        ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
+        ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]]:
             try:
                 res = await self.transport_layer.get_room_hierarchy(
                     destination=destination,
@@ -1397,7 +1394,7 @@ class FederationClient(FederationBase):
                 raise InvalidResponseError("'room' must be a dict")
 
             # Validate children_state of the room.
-            children_state = room.get("children_state", [])
+            children_state = room.pop("children_state", [])
             if not isinstance(children_state, Sequence):
                 raise InvalidResponseError("'room.children_state' must be a list")
             if any(not isinstance(e, dict) for e in children_state):
@@ -1426,7 +1423,7 @@ class FederationClient(FederationBase):
                     "Invalid room ID in 'inaccessible_children' list"
                 )
 
-            return room, children, inaccessible_children
+            return room, children_state, children, inaccessible_children
 
         try:
             result = await self._try_destination_list(
@@ -1474,8 +1471,6 @@ class FederationClient(FederationBase):
                 if event.room_id == room_id:
                     children_events.append(event.data)
                     children_room_ids.add(event.state_key)
-            # And add them under the requested room.
-            requested_room["children_state"] = children_events
 
             # Find the children rooms.
             children = []
@@ -1485,7 +1480,7 @@ class FederationClient(FederationBase):
 
             # It isn't clear from the response whether some of the rooms are
             # not accessible.
-            result = (requested_room, children, ())
+            result = (requested_room, children_events, children, ())
 
         # Cache the result to avoid fetching data over federation every time.
         self._get_room_hierarchy_cache[(room_id, suggested_only)] = result
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ee71f289c8..af9cb98f67 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -58,7 +58,6 @@ from synapse.logging.context import (
     run_in_background,
 )
 from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
-from synapse.logging.utils import log_function
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.replication.http.federation import (
     ReplicationFederationSendEduRestServlet,
@@ -859,7 +858,6 @@ class FederationServer(FederationBase):
             res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
         return 200, res
 
-    @log_function
     async def on_query_client_keys(
         self, origin: str, content: Dict[str, str]
     ) -> Tuple[int, Dict[str, Any]]:
@@ -940,7 +938,6 @@ class FederationServer(FederationBase):
 
         return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
 
-    @log_function
     async def on_openid_userinfo(self, token: str) -> Optional[str]:
         ts_now_ms = self._clock.time_msec()
         return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 523ab1c51e..60e2e6cf01 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -23,7 +23,6 @@ import logging
 from typing import Optional, Tuple
 
 from synapse.federation.units import Transaction
-from synapse.logging.utils import log_function
 from synapse.storage.databases.main import DataStore
 from synapse.types import JsonDict
 
@@ -36,7 +35,6 @@ class TransactionActions:
     def __init__(self, datastore: DataStore):
         self.store = datastore
 
-    @log_function
     async def have_responded(
         self, origin: str, transaction: Transaction
     ) -> Optional[Tuple[int, JsonDict]]:
@@ -53,7 +51,6 @@ class TransactionActions:
 
         return await self.store.get_received_txn_response(transaction_id, origin)
 
-    @log_function
     async def set_response(
         self, origin: str, transaction: Transaction, code: int, response: JsonDict
     ) -> None:
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 391b30fbb5..8152e80b88 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -607,18 +607,18 @@ class PerDestinationQueue:
         self._pending_pdus = []
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class _TransactionQueueManager:
     """A helper async context manager for pulling stuff off the queues and
     tracking what was last successfully sent, etc.
     """
 
-    queue = attr.ib(type=PerDestinationQueue)
+    queue: PerDestinationQueue
 
-    _device_stream_id = attr.ib(type=Optional[int], default=None)
-    _device_list_id = attr.ib(type=Optional[int], default=None)
-    _last_stream_ordering = attr.ib(type=Optional[int], default=None)
-    _pdus = attr.ib(type=List[EventBase], factory=list)
+    _device_stream_id: Optional[int] = None
+    _device_list_id: Optional[int] = None
+    _last_stream_ordering: Optional[int] = None
+    _pdus: List[EventBase] = attr.Factory(list)
 
     async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]:
         # First we calculate the EDUs we want to send, if any.
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index ab935e5a7e..742ee57255 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -35,6 +35,7 @@ if TYPE_CHECKING:
     import synapse.server
 
 logger = logging.getLogger(__name__)
+issue_8631_logger = logging.getLogger("synapse.8631_debug")
 
 last_pdu_ts_metric = Gauge(
     "synapse_federation_last_sent_pdu_time",
@@ -124,6 +125,17 @@ class TransactionManager:
                 len(pdus),
                 len(edus),
             )
+            if issue_8631_logger.isEnabledFor(logging.DEBUG):
+                DEVICE_UPDATE_EDUS = {"m.device_list_update", "m.signing_key_update"}
+                device_list_updates = [
+                    edu.content for edu in edus if edu.edu_type in DEVICE_UPDATE_EDUS
+                ]
+                if device_list_updates:
+                    issue_8631_logger.debug(
+                        "about to send txn [%s] including device list updates: %s",
+                        transaction.transaction_id,
+                        device_list_updates,
+                    )
 
             # Actually send the transaction
 
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 9fc4c31c93..8782586cd6 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -44,7 +44,6 @@ from synapse.api.urls import (
 from synapse.events import EventBase, make_event_from_dict
 from synapse.federation.units import Transaction
 from synapse.http.matrixfederationclient import ByteParser
-from synapse.logging.utils import log_function
 from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
@@ -62,7 +61,6 @@ class TransportLayerClient:
         self.server_name = hs.hostname
         self.client = hs.get_federation_http_client()
 
-    @log_function
     async def get_room_state_ids(
         self, destination: str, room_id: str, event_id: str
     ) -> JsonDict:
@@ -88,7 +86,6 @@ class TransportLayerClient:
             try_trailing_slash_on_400=True,
         )
 
-    @log_function
     async def get_event(
         self, destination: str, event_id: str, timeout: Optional[int] = None
     ) -> JsonDict:
@@ -111,7 +108,6 @@ class TransportLayerClient:
             destination, path=path, timeout=timeout, try_trailing_slash_on_400=True
         )
 
-    @log_function
     async def backfill(
         self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
     ) -> Optional[JsonDict]:
@@ -149,7 +145,6 @@ class TransportLayerClient:
             destination, path=path, args=args, try_trailing_slash_on_400=True
         )
 
-    @log_function
     async def timestamp_to_event(
         self, destination: str, room_id: str, timestamp: int, direction: str
     ) -> Union[JsonDict, List]:
@@ -185,7 +180,6 @@ class TransportLayerClient:
 
         return remote_response
 
-    @log_function
     async def send_transaction(
         self,
         transaction: Transaction,
@@ -234,7 +228,6 @@ class TransportLayerClient:
             try_trailing_slash_on_400=True,
         )
 
-    @log_function
     async def make_query(
         self,
         destination: str,
@@ -254,7 +247,6 @@ class TransportLayerClient:
             ignore_backoff=ignore_backoff,
         )
 
-    @log_function
     async def make_membership_event(
         self,
         destination: str,
@@ -317,7 +309,6 @@ class TransportLayerClient:
             ignore_backoff=ignore_backoff,
         )
 
-    @log_function
     async def send_join_v1(
         self,
         room_version: RoomVersion,
@@ -336,7 +327,6 @@ class TransportLayerClient:
             max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
         )
 
-    @log_function
     async def send_join_v2(
         self,
         room_version: RoomVersion,
@@ -355,7 +345,6 @@ class TransportLayerClient:
             max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
         )
 
-    @log_function
     async def send_leave_v1(
         self, destination: str, room_id: str, event_id: str, content: JsonDict
     ) -> Tuple[int, JsonDict]:
@@ -372,7 +361,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def send_leave_v2(
         self, destination: str, room_id: str, event_id: str, content: JsonDict
     ) -> JsonDict:
@@ -389,7 +377,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def send_knock_v1(
         self,
         destination: str,
@@ -423,7 +410,6 @@ class TransportLayerClient:
             destination=destination, path=path, data=content
         )
 
-    @log_function
     async def send_invite_v1(
         self, destination: str, room_id: str, event_id: str, content: JsonDict
     ) -> Tuple[int, JsonDict]:
@@ -433,7 +419,6 @@ class TransportLayerClient:
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
-    @log_function
     async def send_invite_v2(
         self, destination: str, room_id: str, event_id: str, content: JsonDict
     ) -> JsonDict:
@@ -443,7 +428,6 @@ class TransportLayerClient:
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
-    @log_function
     async def get_public_rooms(
         self,
         remote_server: str,
@@ -516,7 +500,6 @@ class TransportLayerClient:
 
         return response
 
-    @log_function
     async def exchange_third_party_invite(
         self, destination: str, room_id: str, event_dict: JsonDict
     ) -> JsonDict:
@@ -526,7 +509,6 @@ class TransportLayerClient:
             destination=destination, path=path, data=event_dict
         )
 
-    @log_function
     async def get_event_auth(
         self, destination: str, room_id: str, event_id: str
     ) -> JsonDict:
@@ -534,7 +516,6 @@ class TransportLayerClient:
 
         return await self.client.get_json(destination=destination, path=path)
 
-    @log_function
     async def query_client_keys(
         self, destination: str, query_content: JsonDict, timeout: int
     ) -> JsonDict:
@@ -576,7 +557,6 @@ class TransportLayerClient:
             destination=destination, path=path, data=query_content, timeout=timeout
         )
 
-    @log_function
     async def query_user_devices(
         self, destination: str, user_id: str, timeout: int
     ) -> JsonDict:
@@ -616,7 +596,6 @@ class TransportLayerClient:
             destination=destination, path=path, timeout=timeout
         )
 
-    @log_function
     async def claim_client_keys(
         self, destination: str, query_content: JsonDict, timeout: int
     ) -> JsonDict:
@@ -655,7 +634,6 @@ class TransportLayerClient:
             destination=destination, path=path, data=query_content, timeout=timeout
         )
 
-    @log_function
     async def get_missing_events(
         self,
         destination: str,
@@ -680,7 +658,6 @@ class TransportLayerClient:
             timeout=timeout,
         )
 
-    @log_function
     async def get_group_profile(
         self, destination: str, group_id: str, requester_user_id: str
     ) -> JsonDict:
@@ -694,7 +671,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def update_group_profile(
         self, destination: str, group_id: str, requester_user_id: str, content: JsonDict
     ) -> JsonDict:
@@ -716,7 +692,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def get_group_summary(
         self, destination: str, group_id: str, requester_user_id: str
     ) -> JsonDict:
@@ -730,7 +705,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def get_rooms_in_group(
         self, destination: str, group_id: str, requester_user_id: str
     ) -> JsonDict:
@@ -798,7 +772,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def get_users_in_group(
         self, destination: str, group_id: str, requester_user_id: str
     ) -> JsonDict:
@@ -812,7 +785,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def get_invited_users_in_group(
         self, destination: str, group_id: str, requester_user_id: str
     ) -> JsonDict:
@@ -826,7 +798,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def accept_group_invite(
         self, destination: str, group_id: str, user_id: str, content: JsonDict
     ) -> JsonDict:
@@ -837,7 +808,6 @@ class TransportLayerClient:
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
-    @log_function
     def join_group(
         self, destination: str, group_id: str, user_id: str, content: JsonDict
     ) -> Awaitable[JsonDict]:
@@ -848,7 +818,6 @@ class TransportLayerClient:
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
-    @log_function
     async def invite_to_group(
         self,
         destination: str,
@@ -868,7 +837,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def invite_to_group_notification(
         self, destination: str, group_id: str, user_id: str, content: JsonDict
     ) -> JsonDict:
@@ -882,7 +850,6 @@ class TransportLayerClient:
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
-    @log_function
     async def remove_user_from_group(
         self,
         destination: str,
@@ -902,7 +869,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def remove_user_from_group_notification(
         self, destination: str, group_id: str, user_id: str, content: JsonDict
     ) -> JsonDict:
@@ -916,7 +882,6 @@ class TransportLayerClient:
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
-    @log_function
     async def renew_group_attestation(
         self, destination: str, group_id: str, user_id: str, content: JsonDict
     ) -> JsonDict:
@@ -930,7 +895,6 @@ class TransportLayerClient:
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
-    @log_function
     async def update_group_summary_room(
         self,
         destination: str,
@@ -959,7 +923,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def delete_group_summary_room(
         self,
         destination: str,
@@ -986,7 +949,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def get_group_categories(
         self, destination: str, group_id: str, requester_user_id: str
     ) -> JsonDict:
@@ -1000,7 +962,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def get_group_category(
         self, destination: str, group_id: str, requester_user_id: str, category_id: str
     ) -> JsonDict:
@@ -1014,7 +975,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def update_group_category(
         self,
         destination: str,
@@ -1034,7 +994,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def delete_group_category(
         self, destination: str, group_id: str, requester_user_id: str, category_id: str
     ) -> JsonDict:
@@ -1048,7 +1007,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def get_group_roles(
         self, destination: str, group_id: str, requester_user_id: str
     ) -> JsonDict:
@@ -1062,7 +1020,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def get_group_role(
         self, destination: str, group_id: str, requester_user_id: str, role_id: str
     ) -> JsonDict:
@@ -1076,7 +1033,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def update_group_role(
         self,
         destination: str,
@@ -1096,7 +1052,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def delete_group_role(
         self, destination: str, group_id: str, requester_user_id: str, role_id: str
     ) -> JsonDict:
@@ -1110,7 +1065,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def update_group_summary_user(
         self,
         destination: str,
@@ -1136,7 +1090,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def set_group_join_policy(
         self, destination: str, group_id: str, requester_user_id: str, content: JsonDict
     ) -> JsonDict:
@@ -1151,7 +1104,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    @log_function
     async def delete_group_summary_user(
         self,
         destination: str,
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
index 77b936361a..db4fe2c798 100644
--- a/synapse/federation/transport/server/__init__.py
+++ b/synapse/federation/transport/server/__init__.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Dict, Iterable, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Type
 
 from typing_extensions import Literal
 
@@ -36,17 +36,19 @@ from synapse.http.servlet import (
     parse_integer_from_args,
     parse_string_from_args,
 )
-from synapse.server import HomeServer
 from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util.ratelimitutils import FederationRateLimiter
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class TransportLayerServer(JsonResource):
     """Handles incoming federation HTTP requests"""
 
-    def __init__(self, hs: HomeServer, servlet_groups: Optional[List[str]] = None):
+    def __init__(self, hs: "HomeServer", servlet_groups: Optional[List[str]] = None):
         """Initialize the TransportLayerServer
 
         Will by default register all servlets. For custom behaviour, pass in
@@ -113,7 +115,7 @@ class PublicRoomList(BaseFederationServlet):
 
     def __init__(
         self,
-        hs: HomeServer,
+        hs: "HomeServer",
         authenticator: Authenticator,
         ratelimiter: FederationRateLimiter,
         server_name: str,
@@ -203,7 +205,7 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
 
     def __init__(
         self,
-        hs: HomeServer,
+        hs: "HomeServer",
         authenticator: Authenticator,
         ratelimiter: FederationRateLimiter,
         server_name: str,
@@ -251,7 +253,7 @@ class OpenIdUserInfo(BaseFederationServlet):
 
     def __init__(
         self,
-        hs: HomeServer,
+        hs: "HomeServer",
         authenticator: Authenticator,
         ratelimiter: FederationRateLimiter,
         server_name: str,
@@ -297,7 +299,7 @@ DEFAULT_SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = {
 
 
 def register_servlets(
-    hs: HomeServer,
+    hs: "HomeServer",
     resource: HttpServer,
     authenticator: Authenticator,
     ratelimiter: FederationRateLimiter,
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index da1fbf8b63..dff2b68359 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -15,7 +15,8 @@
 import functools
 import logging
 import re
-from typing import Any, Awaitable, Callable, Optional, Tuple, cast
+import time
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, cast
 
 from synapse.api.errors import Codes, FederationDeniedError, SynapseError
 from synapse.api.urls import FEDERATION_V1_PREFIX
@@ -24,16 +25,20 @@ from synapse.http.servlet import parse_json_object_from_request
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
+    active_span,
     set_tag,
     span_context_from_request,
+    start_active_span,
     start_active_span_follows_from,
     whitelisted_homeserver,
 )
-from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.stringutils import parse_and_validate_server_name
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -46,7 +51,7 @@ class NoAuthenticationError(AuthenticationError):
 
 
 class Authenticator:
-    def __init__(self, hs: HomeServer):
+    def __init__(self, hs: "HomeServer"):
         self._clock = hs.get_clock()
         self.keyring = hs.get_keyring()
         self.server_name = hs.hostname
@@ -114,11 +119,11 @@ class Authenticator:
         # alive
         retry_timings = await self.store.get_destination_retry_timings(origin)
         if retry_timings and retry_timings.retry_last_ts:
-            run_in_background(self._reset_retry_timings, origin)
+            run_in_background(self.reset_retry_timings, origin)
 
         return origin
 
-    async def _reset_retry_timings(self, origin: str) -> None:
+    async def reset_retry_timings(self, origin: str) -> None:
         try:
             logger.info("Marking origin %r as up", origin)
             await self.store.set_destination_retry_timings(origin, None, 0, 0)
@@ -227,7 +232,7 @@ class BaseFederationServlet:
 
     def __init__(
         self,
-        hs: HomeServer,
+        hs: "HomeServer",
         authenticator: Authenticator,
         ratelimiter: FederationRateLimiter,
         server_name: str,
@@ -263,9 +268,10 @@ class BaseFederationServlet:
                 content = parse_json_object_from_request(request)
 
             try:
-                origin: Optional[str] = await authenticator.authenticate_request(
-                    request, content
-                )
+                with start_active_span("authenticate_request"):
+                    origin: Optional[str] = await authenticator.authenticate_request(
+                        request, content
+                    )
             except NoAuthenticationError:
                 origin = None
                 if self.REQUIRE_AUTH:
@@ -280,32 +286,57 @@ class BaseFederationServlet:
             # update the active opentracing span with the authenticated entity
             set_tag("authenticated_entity", origin)
 
-            # if the origin is authenticated and whitelisted, link to its span context
+            # if the origin is authenticated and whitelisted, use its span context
+            # as the parent.
             context = None
             if origin and whitelisted_homeserver(origin):
                 context = span_context_from_request(request)
 
-            scope = start_active_span_follows_from(
-                "incoming-federation-request", contexts=(context,) if context else ()
-            )
+            if context:
+                servlet_span = active_span()
+                # a scope which uses the origin's context as a parent
+                processing_start_time = time.time()
+                scope = start_active_span_follows_from(
+                    "incoming-federation-request",
+                    child_of=context,
+                    contexts=(servlet_span,),
+                    start_time=processing_start_time,
+                )
 
-            with scope:
-                if origin and self.RATELIMIT:
-                    with ratelimiter.ratelimit(origin) as d:
-                        await d
-                        if request._disconnected:
-                            logger.warning(
-                                "client disconnected before we started processing "
-                                "request"
+            else:
+                # just use our context as a parent
+                scope = start_active_span(
+                    "incoming-federation-request",
+                )
+
+            try:
+                with scope:
+                    if origin and self.RATELIMIT:
+                        with ratelimiter.ratelimit(origin) as d:
+                            await d
+                            if request._disconnected:
+                                logger.warning(
+                                    "client disconnected before we started processing "
+                                    "request"
+                                )
+                                return None
+                            response = await func(
+                                origin, content, request.args, *args, **kwargs
                             )
-                            return None
+                    else:
                         response = await func(
                             origin, content, request.args, *args, **kwargs
                         )
-                else:
-                    response = await func(
-                        origin, content, request.args, *args, **kwargs
+            finally:
+                # if we used the origin's context as the parent, add a new span using
+                # the servlet span as a parent, so that we have a link
+                if context:
+                    scope2 = start_active_span_follows_from(
+                        "process-federation_request",
+                        contexts=(scope.span,),
+                        start_time=processing_start_time,
                     )
+                    scope2.close()
 
             return response
 
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 77bfd88ad0..d86dfede4e 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -12,7 +12,17 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 import logging
-from typing import Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
 
 from typing_extensions import Literal
 
@@ -30,12 +40,15 @@ from synapse.http.servlet import (
     parse_string_from_args,
     parse_strings_from_args,
 )
-from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.versionstring import get_version_string
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
+issue_8631_logger = logging.getLogger("synapse.8631_debug")
 
 
 class BaseFederationServerServlet(BaseFederationServlet):
@@ -46,7 +59,7 @@ class BaseFederationServerServlet(BaseFederationServlet):
 
     def __init__(
         self,
-        hs: HomeServer,
+        hs: "HomeServer",
         authenticator: Authenticator,
         ratelimiter: FederationRateLimiter,
         server_name: str,
@@ -95,6 +108,20 @@ class FederationSendServlet(BaseFederationServerServlet):
                 len(transaction_data.get("edus", [])),
             )
 
+            if issue_8631_logger.isEnabledFor(logging.DEBUG):
+                DEVICE_UPDATE_EDUS = ["m.device_list_update", "m.signing_key_update"]
+                device_list_updates = [
+                    edu.content
+                    for edu in transaction_data.get("edus", [])
+                    if edu.get("edu_type") in DEVICE_UPDATE_EDUS
+                ]
+                if device_list_updates:
+                    issue_8631_logger.debug(
+                        "received transaction [%s] including device list updates: %s",
+                        transaction_id,
+                        device_list_updates,
+                    )
+
         except Exception as e:
             logger.exception(e)
             return 400, {"error": "Invalid transaction"}
@@ -581,7 +608,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
 
     def __init__(
         self,
-        hs: HomeServer,
+        hs: "HomeServer",
         authenticator: Authenticator,
         ratelimiter: FederationRateLimiter,
         server_name: str,
@@ -655,7 +682,7 @@ class FederationRoomHierarchyServlet(BaseFederationServlet):
 
     def __init__(
         self,
-        hs: HomeServer,
+        hs: "HomeServer",
         authenticator: Authenticator,
         ratelimiter: FederationRateLimiter,
         server_name: str,
@@ -691,7 +718,7 @@ class RoomComplexityServlet(BaseFederationServlet):
 
     def __init__(
         self,
-        hs: HomeServer,
+        hs: "HomeServer",
         authenticator: Authenticator,
         ratelimiter: FederationRateLimiter,
         server_name: str,
diff --git a/synapse/federation/transport/server/groups_local.py b/synapse/federation/transport/server/groups_local.py
index a12cd18d58..496472e1dc 100644
--- a/synapse/federation/transport/server/groups_local.py
+++ b/synapse/federation/transport/server/groups_local.py
@@ -11,7 +11,7 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
-from typing import Dict, List, Tuple, Type
+from typing import TYPE_CHECKING, Dict, List, Tuple, Type
 
 from synapse.api.errors import SynapseError
 from synapse.federation.transport.server._base import (
@@ -19,10 +19,12 @@ from synapse.federation.transport.server._base import (
     BaseFederationServlet,
 )
 from synapse.handlers.groups_local import GroupsLocalHandler
-from synapse.server import HomeServer
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.ratelimitutils import FederationRateLimiter
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 class BaseGroupsLocalServlet(BaseFederationServlet):
     """Abstract base class for federation servlet classes which provides a groups local handler.
@@ -32,7 +34,7 @@ class BaseGroupsLocalServlet(BaseFederationServlet):
 
     def __init__(
         self,
-        hs: HomeServer,
+        hs: "HomeServer",
         authenticator: Authenticator,
         ratelimiter: FederationRateLimiter,
         server_name: str,
diff --git a/synapse/federation/transport/server/groups_server.py b/synapse/federation/transport/server/groups_server.py
index b30e92a5eb..851b50152e 100644
--- a/synapse/federation/transport/server/groups_server.py
+++ b/synapse/federation/transport/server/groups_server.py
@@ -11,7 +11,7 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
-from typing import Dict, List, Tuple, Type
+from typing import TYPE_CHECKING, Dict, List, Tuple, Type
 
 from typing_extensions import Literal
 
@@ -22,10 +22,12 @@ from synapse.federation.transport.server._base import (
     BaseFederationServlet,
 )
 from synapse.http.servlet import parse_string_from_args
-from synapse.server import HomeServer
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.ratelimitutils import FederationRateLimiter
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 class BaseGroupsServerServlet(BaseFederationServlet):
     """Abstract base class for federation servlet classes which provides a groups server handler.
@@ -35,7 +37,7 @@ class BaseGroupsServerServlet(BaseFederationServlet):
 
     def __init__(
         self,
-        hs: HomeServer,
+        hs: "HomeServer",
         authenticator: Authenticator,
         ratelimiter: FederationRateLimiter,
         server_name: str,