diff options
Diffstat (limited to 'synapse/federation/federation_client.py')
-rw-r--r-- | synapse/federation/federation_client.py | 60 |
1 files changed, 59 insertions, 1 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index c2997997da..2121e92e3a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -56,7 +56,7 @@ from synapse.api.room_versions import ( from synapse.events import EventBase, builder from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.transport.client import SendJoinResponse -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -1610,6 +1610,64 @@ class FederationClient(FederationBase): except ValueError as e: raise InvalidResponseError(str(e)) + async def get_account_status( + self, destination: str, user_ids: List[str] + ) -> Tuple[JsonDict, List[str]]: + """Retrieves account statuses for a given list of users on a given remote + homeserver. + + If the request fails for any reason, all user IDs for this destination are marked + as failed. + + Args: + destination: the destination to contact + user_ids: the user ID(s) for which to request account status(es) + + Returns: + The account statuses, as well as the list of user IDs for which it was not + possible to retrieve a status. + """ + try: + res = await self.transport_layer.get_account_status(destination, user_ids) + except Exception: + # If the query failed for any reason, mark all the users as failed. + return {}, user_ids + + statuses = res.get("account_statuses", {}) + failures = res.get("failures", []) + + if not isinstance(statuses, dict) or not isinstance(failures, list): + # Make sure we're not feeding back malformed data back to the caller. + logger.warning( + "Destination %s responded with malformed data to account_status query", + destination, + ) + return {}, user_ids + + for user_id in user_ids: + # Any account whose status is missing is a user we failed to receive the + # status of. + if user_id not in statuses and user_id not in failures: + failures.append(user_id) + + # Filter out any user ID that doesn't belong to the remote server that sent its + # status (or failure). + def filter_user_id(user_id: str) -> bool: + try: + return UserID.from_string(user_id).domain == destination + except SynapseError: + # If the user ID doesn't parse, ignore it. + return False + + filtered_statuses = dict( + # item is a (key, value) tuple, so item[0] is the user ID. + filter(lambda item: filter_user_id(item[0]), statuses.items()) + ) + + filtered_failures = list(filter(filter_user_id, failures)) + + return filtered_statuses, filtered_failures + @attr.s(frozen=True, slots=True, auto_attribs=True) class TimestampToEventResponse: |