From 4fb92d93eae3c3519a9a47ea7fdef2d492014f2f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 26 Jul 2021 11:53:09 -0400 Subject: Add type hints to synapse.federation.transport.client. (#10408) --- changelog.d/10408.misc | 1 + synapse/federation/transport/client.py | 499 ++++++++++++++++++++------------- 2 files changed, 299 insertions(+), 201 deletions(-) create mode 100644 changelog.d/10408.misc diff --git a/changelog.d/10408.misc b/changelog.d/10408.misc new file mode 100644 index 0000000000..abccd210a9 --- /dev/null +++ b/changelog.d/10408.misc @@ -0,0 +1 @@ +Add type hints to `synapse.federation.transport.client` module. diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 98b1bf77fd..e73bdb52b3 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -15,7 +15,7 @@ import logging import urllib -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union import attr import ijson @@ -29,6 +29,7 @@ from synapse.api.urls import ( FEDERATION_V2_PREFIX, ) from synapse.events import EventBase, make_event_from_dict +from synapse.federation.units import Transaction from synapse.http.matrixfederationclient import ByteParser from synapse.logging.utils import log_function from synapse.types import JsonDict @@ -49,23 +50,25 @@ class TransportLayerClient: self.client = hs.get_federation_http_client() @log_function - def get_room_state_ids(self, destination, room_id, event_id): + async def get_room_state_ids( + self, destination: str, room_id: str, event_id: str + ) -> JsonDict: """Requests all state for a given room from the given server at the given event. Returns the state's event_id's Args: - destination (str): The host name of the remote homeserver we want + destination: The host name of the remote homeserver we want to get the state from. - context (str): The name of the context we want the state of - event_id (str): The event we want the context at. + context: The name of the context we want the state of + event_id: The event we want the context at. Returns: - Awaitable: Results in a dict received from the remote homeserver. + Results in a dict received from the remote homeserver. """ logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) path = _create_v1_path("/state_ids/%s", room_id) - return self.client.get_json( + return await self.client.get_json( destination, path=path, args={"event_id": event_id}, @@ -73,39 +76,43 @@ class TransportLayerClient: ) @log_function - def get_event(self, destination, event_id, timeout=None): + async def get_event( + self, destination: str, event_id: str, timeout: Optional[int] = None + ) -> JsonDict: """Requests the pdu with give id and origin from the given server. Args: - destination (str): The host name of the remote homeserver we want + destination: The host name of the remote homeserver we want to get the state from. - event_id (str): The id of the event being requested. - timeout (int): How long to try (in ms) the destination for before + event_id: The id of the event being requested. + timeout: How long to try (in ms) the destination for before giving up. None indicates no timeout. Returns: - Awaitable: Results in a dict received from the remote homeserver. + Results in a dict received from the remote homeserver. """ logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) path = _create_v1_path("/event/%s", event_id) - return self.client.get_json( + return await self.client.get_json( destination, path=path, timeout=timeout, try_trailing_slash_on_400=True ) @log_function - def backfill(self, destination, room_id, event_tuples, limit): + async def backfill( + self, destination: str, room_id: str, event_tuples: Iterable[str], limit: int + ) -> Optional[JsonDict]: """Requests `limit` previous PDUs in a given context before list of PDUs. Args: - dest (str) - room_id (str) - event_tuples (list) - limit (int) + destination + room_id + event_tuples + limit Returns: - Awaitable: Results in a dict received from the remote homeserver. + Results in a dict received from the remote homeserver. """ logger.debug( "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s", @@ -117,18 +124,22 @@ class TransportLayerClient: if not event_tuples: # TODO: raise? - return + return None path = _create_v1_path("/backfill/%s", room_id) args = {"v": event_tuples, "limit": [str(limit)]} - return self.client.get_json( + return await self.client.get_json( destination, path=path, args=args, try_trailing_slash_on_400=True ) @log_function - async def send_transaction(self, transaction, json_data_callback=None): + async def send_transaction( + self, + transaction: Transaction, + json_data_callback: Optional[Callable[[], JsonDict]] = None, + ) -> JsonDict: """Sends the given Transaction to its destination Args: @@ -149,21 +160,21 @@ class TransportLayerClient: """ logger.debug( "send_data dest=%s, txid=%s", - transaction.destination, - transaction.transaction_id, + transaction.destination, # type: ignore + transaction.transaction_id, # type: ignore ) - if transaction.destination == self.server_name: + if transaction.destination == self.server_name: # type: ignore raise RuntimeError("Transport layer cannot send to itself!") # FIXME: This is only used by the tests. The actual json sent is # generated by the json_data_callback. json_data = transaction.get_dict() - path = _create_v1_path("/send/%s", transaction.transaction_id) + path = _create_v1_path("/send/%s", transaction.transaction_id) # type: ignore - response = await self.client.put_json( - transaction.destination, + return await self.client.put_json( + transaction.destination, # type: ignore path=path, data=json_data, json_data_callback=json_data_callback, @@ -172,8 +183,6 @@ class TransportLayerClient: try_trailing_slash_on_400=True, ) - return response - @log_function async def make_query( self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False @@ -193,8 +202,13 @@ class TransportLayerClient: @log_function async def make_membership_event( - self, destination, room_id, user_id, membership, params - ): + self, + destination: str, + room_id: str, + user_id: str, + membership: str, + params: Optional[Mapping[str, Union[str, Iterable[str]]]], + ) -> JsonDict: """Asks a remote server to build and sign us a membership event Note that this does not append any events to any graphs. @@ -240,7 +254,7 @@ class TransportLayerClient: ignore_backoff = True retry_on_dns_fail = True - content = await self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args=params, @@ -249,20 +263,18 @@ class TransportLayerClient: ignore_backoff=ignore_backoff, ) - return content - @log_function async def send_join_v1( self, - room_version, - destination, - room_id, - event_id, - content, + room_version: RoomVersion, + destination: str, + room_id: str, + event_id: str, + content: JsonDict, ) -> "SendJoinResponse": path = _create_v1_path("/send_join/%s/%s", room_id, event_id) - response = await self.client.put_json( + return await self.client.put_json( destination=destination, path=path, data=content, @@ -270,15 +282,18 @@ class TransportLayerClient: max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN, ) - return response - @log_function async def send_join_v2( - self, room_version, destination, room_id, event_id, content + self, + room_version: RoomVersion, + destination: str, + room_id: str, + event_id: str, + content: JsonDict, ) -> "SendJoinResponse": path = _create_v2_path("/send_join/%s/%s", room_id, event_id) - response = await self.client.put_json( + return await self.client.put_json( destination=destination, path=path, data=content, @@ -286,13 +301,13 @@ class TransportLayerClient: max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN, ) - return response - @log_function - async def send_leave_v1(self, destination, room_id, event_id, content): + async def send_leave_v1( + self, destination: str, room_id: str, event_id: str, content: JsonDict + ) -> Tuple[int, JsonDict]: path = _create_v1_path("/send_leave/%s/%s", room_id, event_id) - response = await self.client.put_json( + return await self.client.put_json( destination=destination, path=path, data=content, @@ -303,13 +318,13 @@ class TransportLayerClient: ignore_backoff=True, ) - return response - @log_function - async def send_leave_v2(self, destination, room_id, event_id, content): + async def send_leave_v2( + self, destination: str, room_id: str, event_id: str, content: JsonDict + ) -> JsonDict: path = _create_v2_path("/send_leave/%s/%s", room_id, event_id) - response = await self.client.put_json( + return await self.client.put_json( destination=destination, path=path, data=content, @@ -320,8 +335,6 @@ class TransportLayerClient: ignore_backoff=True, ) - return response - @log_function async def send_knock_v1( self, @@ -357,25 +370,25 @@ class TransportLayerClient: ) @log_function - async def send_invite_v1(self, destination, room_id, event_id, content): + async def send_invite_v1( + self, destination: str, room_id: str, event_id: str, content: JsonDict + ) -> Tuple[int, JsonDict]: path = _create_v1_path("/invite/%s/%s", room_id, event_id) - response = await self.client.put_json( + return await self.client.put_json( destination=destination, path=path, data=content, ignore_backoff=True ) - return response - @log_function - async def send_invite_v2(self, destination, room_id, event_id, content): + async def send_invite_v2( + self, destination: str, room_id: str, event_id: str, content: JsonDict + ) -> JsonDict: path = _create_v2_path("/invite/%s/%s", room_id, event_id) - response = await self.client.put_json( + return await self.client.put_json( destination=destination, path=path, data=content, ignore_backoff=True ) - return response - @log_function async def get_public_rooms( self, @@ -385,7 +398,7 @@ class TransportLayerClient: search_filter: Optional[Dict] = None, include_all_networks: bool = False, third_party_instance_id: Optional[str] = None, - ): + ) -> JsonDict: """Get the list of public rooms from a remote homeserver See synapse.federation.federation_client.FederationClient.get_public_rooms for @@ -450,25 +463,27 @@ class TransportLayerClient: return response @log_function - async def exchange_third_party_invite(self, destination, room_id, event_dict): + async def exchange_third_party_invite( + self, destination: str, room_id: str, event_dict: JsonDict + ) -> JsonDict: path = _create_v1_path("/exchange_third_party_invite/%s", room_id) - response = await self.client.put_json( + return await self.client.put_json( destination=destination, path=path, data=event_dict ) - return response - @log_function - async def get_event_auth(self, destination, room_id, event_id): + async def get_event_auth( + self, destination: str, room_id: str, event_id: str + ) -> JsonDict: path = _create_v1_path("/event_auth/%s/%s", room_id, event_id) - content = await self.client.get_json(destination=destination, path=path) - - return content + return await self.client.get_json(destination=destination, path=path) @log_function - async def query_client_keys(self, destination, query_content, timeout): + async def query_client_keys( + self, destination: str, query_content: JsonDict, timeout: int + ) -> JsonDict: """Query the device keys for a list of user ids hosted on a remote server. @@ -496,20 +511,21 @@ class TransportLayerClient: } Args: - destination(str): The server to query. - query_content(dict): The user ids to query. + destination: The server to query. + query_content: The user ids to query. Returns: A dict containing device and cross-signing keys. """ path = _create_v1_path("/user/keys/query") - content = await self.client.post_json( + return await self.client.post_json( destination=destination, path=path, data=query_content, timeout=timeout ) - return content @log_function - async def query_user_devices(self, destination, user_id, timeout): + async def query_user_devices( + self, destination: str, user_id: str, timeout: int + ) -> JsonDict: """Query the devices for a user id hosted on a remote server. Response: @@ -535,20 +551,21 @@ class TransportLayerClient: } Args: - destination(str): The server to query. - query_content(dict): The user ids to query. + destination: The server to query. + query_content: The user ids to query. Returns: A dict containing device and cross-signing keys. """ path = _create_v1_path("/user/devices/%s", user_id) - content = await self.client.get_json( + return await self.client.get_json( destination=destination, path=path, timeout=timeout ) - return content @log_function - async def claim_client_keys(self, destination, query_content, timeout): + async def claim_client_keys( + self, destination: str, query_content: JsonDict, timeout: int + ) -> JsonDict: """Claim one-time keys for a list of devices hosted on a remote server. Request: @@ -572,33 +589,32 @@ class TransportLayerClient: } Args: - destination(str): The server to query. - query_content(dict): The user ids to query. + destination: The server to query. + query_content: The user ids to query. Returns: A dict containing the one-time keys. """ path = _create_v1_path("/user/keys/claim") - content = await self.client.post_json( + return await self.client.post_json( destination=destination, path=path, data=query_content, timeout=timeout ) - return content @log_function async def get_missing_events( self, - destination, - room_id, - earliest_events, - latest_events, - limit, - min_depth, - timeout, - ): + destination: str, + room_id: str, + earliest_events: Iterable[str], + latest_events: Iterable[str], + limit: int, + min_depth: int, + timeout: int, + ) -> JsonDict: path = _create_v1_path("/get_missing_events/%s", room_id) - content = await self.client.post_json( + return await self.client.post_json( destination=destination, path=path, data={ @@ -610,14 +626,14 @@ class TransportLayerClient: timeout=timeout, ) - return content - @log_function - def get_group_profile(self, destination, group_id, requester_user_id): + async def get_group_profile( + self, destination: str, group_id: str, requester_user_id: str + ) -> JsonDict: """Get a group profile""" path = _create_v1_path("/groups/%s/profile", group_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -625,14 +641,16 @@ class TransportLayerClient: ) @log_function - def update_group_profile(self, destination, group_id, requester_user_id, content): + async def update_group_profile( + self, destination: str, group_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: """Update a remote group profile Args: - destination (str) - group_id (str) - requester_user_id (str) - content (dict): The new profile of the group + destination + group_id + requester_user_id + content: The new profile of the group """ path = _create_v1_path("/groups/%s/profile", group_id) @@ -645,11 +663,13 @@ class TransportLayerClient: ) @log_function - def get_group_summary(self, destination, group_id, requester_user_id): + async def get_group_summary( + self, destination: str, group_id: str, requester_user_id: str + ) -> JsonDict: """Get a group summary""" path = _create_v1_path("/groups/%s/summary", group_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -657,24 +677,31 @@ class TransportLayerClient: ) @log_function - def get_rooms_in_group(self, destination, group_id, requester_user_id): + async def get_rooms_in_group( + self, destination: str, group_id: str, requester_user_id: str + ) -> JsonDict: """Get all rooms in a group""" path = _create_v1_path("/groups/%s/rooms", group_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, ignore_backoff=True, ) - def add_room_to_group( - self, destination, group_id, requester_user_id, room_id, content - ): + async def add_room_to_group( + self, + destination: str, + group_id: str, + requester_user_id: str, + room_id: str, + content: JsonDict, + ) -> JsonDict: """Add a room to a group""" path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -682,15 +709,21 @@ class TransportLayerClient: ignore_backoff=True, ) - def update_room_in_group( - self, destination, group_id, requester_user_id, room_id, config_key, content - ): + async def update_room_in_group( + self, + destination: str, + group_id: str, + requester_user_id: str, + room_id: str, + config_key: str, + content: JsonDict, + ) -> JsonDict: """Update room in group""" path = _create_v1_path( "/groups/%s/room/%s/config/%s", group_id, room_id, config_key ) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -698,11 +731,13 @@ class TransportLayerClient: ignore_backoff=True, ) - def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): + async def remove_room_from_group( + self, destination: str, group_id: str, requester_user_id: str, room_id: str + ) -> JsonDict: """Remove a room from a group""" path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) - return self.client.delete_json( + return await self.client.delete_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -710,11 +745,13 @@ class TransportLayerClient: ) @log_function - def get_users_in_group(self, destination, group_id, requester_user_id): + async def get_users_in_group( + self, destination: str, group_id: str, requester_user_id: str + ) -> JsonDict: """Get users in a group""" path = _create_v1_path("/groups/%s/users", group_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -722,11 +759,13 @@ class TransportLayerClient: ) @log_function - def get_invited_users_in_group(self, destination, group_id, requester_user_id): + async def get_invited_users_in_group( + self, destination: str, group_id: str, requester_user_id: str + ) -> JsonDict: """Get users that have been invited to a group""" path = _create_v1_path("/groups/%s/invited_users", group_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -734,16 +773,20 @@ class TransportLayerClient: ) @log_function - def accept_group_invite(self, destination, group_id, user_id, content): + async def accept_group_invite( + self, destination: str, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Accept a group invite""" path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def join_group(self, destination, group_id, user_id, content): + def join_group( + self, destination: str, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Attempts to join a group""" path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) @@ -752,13 +795,18 @@ class TransportLayerClient: ) @log_function - def invite_to_group( - self, destination, group_id, user_id, requester_user_id, content - ): + async def invite_to_group( + self, + destination: str, + group_id: str, + user_id: str, + requester_user_id: str, + content: JsonDict, + ) -> JsonDict: """Invite a user to a group""" path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -767,25 +815,32 @@ class TransportLayerClient: ) @log_function - def invite_to_group_notification(self, destination, group_id, user_id, content): + async def invite_to_group_notification( + self, destination: str, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Sent by group server to inform a user's server that they have been invited. """ path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def remove_user_from_group( - self, destination, group_id, requester_user_id, user_id, content - ): + async def remove_user_from_group( + self, + destination: str, + group_id: str, + requester_user_id: str, + user_id: str, + content: JsonDict, + ) -> JsonDict: """Remove a user from a group""" path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -794,35 +849,43 @@ class TransportLayerClient: ) @log_function - def remove_user_from_group_notification( - self, destination, group_id, user_id, content - ): + async def remove_user_from_group_notification( + self, destination: str, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Sent by group server to inform a user's server that they have been kicked from the group. """ path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def renew_group_attestation(self, destination, group_id, user_id, content): + async def renew_group_attestation( + self, destination: str, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Sent by either a group server or a user's server to periodically update the attestations """ path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def update_group_summary_room( - self, destination, group_id, user_id, room_id, category_id, content - ): + async def update_group_summary_room( + self, + destination: str, + group_id: str, + user_id: str, + room_id: str, + category_id: str, + content: JsonDict, + ) -> JsonDict: """Update a room entry in a group summary""" if category_id: path = _create_v1_path( @@ -834,7 +897,7 @@ class TransportLayerClient: else: path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, args={"requester_user_id": user_id}, @@ -843,9 +906,14 @@ class TransportLayerClient: ) @log_function - def delete_group_summary_room( - self, destination, group_id, user_id, room_id, category_id - ): + async def delete_group_summary_room( + self, + destination: str, + group_id: str, + user_id: str, + room_id: str, + category_id: str, + ) -> JsonDict: """Delete a room entry in a group summary""" if category_id: path = _create_v1_path( @@ -857,7 +925,7 @@ class TransportLayerClient: else: path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id) - return self.client.delete_json( + return await self.client.delete_json( destination=destination, path=path, args={"requester_user_id": user_id}, @@ -865,11 +933,13 @@ class TransportLayerClient: ) @log_function - def get_group_categories(self, destination, group_id, requester_user_id): + async def get_group_categories( + self, destination: str, group_id: str, requester_user_id: str + ) -> JsonDict: """Get all categories in a group""" path = _create_v1_path("/groups/%s/categories", group_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -877,11 +947,13 @@ class TransportLayerClient: ) @log_function - def get_group_category(self, destination, group_id, requester_user_id, category_id): + async def get_group_category( + self, destination: str, group_id: str, requester_user_id: str, category_id: str + ) -> JsonDict: """Get category info in a group""" path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -889,13 +961,18 @@ class TransportLayerClient: ) @log_function - def update_group_category( - self, destination, group_id, requester_user_id, category_id, content - ): + async def update_group_category( + self, + destination: str, + group_id: str, + requester_user_id: str, + category_id: str, + content: JsonDict, + ) -> JsonDict: """Update a category in a group""" path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -904,13 +981,13 @@ class TransportLayerClient: ) @log_function - def delete_group_category( - self, destination, group_id, requester_user_id, category_id - ): + async def delete_group_category( + self, destination: str, group_id: str, requester_user_id: str, category_id: str + ) -> JsonDict: """Delete a category in a group""" path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) - return self.client.delete_json( + return await self.client.delete_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -918,11 +995,13 @@ class TransportLayerClient: ) @log_function - def get_group_roles(self, destination, group_id, requester_user_id): + async def get_group_roles( + self, destination: str, group_id: str, requester_user_id: str + ) -> JsonDict: """Get all roles in a group""" path = _create_v1_path("/groups/%s/roles", group_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -930,11 +1009,13 @@ class TransportLayerClient: ) @log_function - def get_group_role(self, destination, group_id, requester_user_id, role_id): + async def get_group_role( + self, destination: str, group_id: str, requester_user_id: str, role_id: str + ) -> JsonDict: """Get a roles info""" path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -942,13 +1023,18 @@ class TransportLayerClient: ) @log_function - def update_group_role( - self, destination, group_id, requester_user_id, role_id, content - ): + async def update_group_role( + self, + destination: str, + group_id: str, + requester_user_id: str, + role_id: str, + content: JsonDict, + ) -> JsonDict: """Update a role in a group""" path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -957,11 +1043,13 @@ class TransportLayerClient: ) @log_function - def delete_group_role(self, destination, group_id, requester_user_id, role_id): + async def delete_group_role( + self, destination: str, group_id: str, requester_user_id: str, role_id: str + ) -> JsonDict: """Delete a role in a group""" path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) - return self.client.delete_json( + return await self.client.delete_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -969,9 +1057,15 @@ class TransportLayerClient: ) @log_function - def update_group_summary_user( - self, destination, group_id, requester_user_id, user_id, role_id, content - ): + async def update_group_summary_user( + self, + destination: str, + group_id: str, + requester_user_id: str, + user_id: str, + role_id: str, + content: JsonDict, + ) -> JsonDict: """Update a users entry in a group""" if role_id: path = _create_v1_path( @@ -980,7 +1074,7 @@ class TransportLayerClient: else: path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -989,11 +1083,13 @@ class TransportLayerClient: ) @log_function - def set_group_join_policy(self, destination, group_id, requester_user_id, content): + async def set_group_join_policy( + self, destination: str, group_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: """Sets the join policy for a group""" path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id) - return self.client.put_json( + return await self.client.put_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -1002,9 +1098,14 @@ class TransportLayerClient: ) @log_function - def delete_group_summary_user( - self, destination, group_id, requester_user_id, user_id, role_id - ): + async def delete_group_summary_user( + self, + destination: str, + group_id: str, + requester_user_id: str, + user_id: str, + role_id: str, + ) -> JsonDict: """Delete a users entry in a group""" if role_id: path = _create_v1_path( @@ -1013,33 +1114,35 @@ class TransportLayerClient: else: path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id) - return self.client.delete_json( + return await self.client.delete_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, ignore_backoff=True, ) - def bulk_get_publicised_groups(self, destination, user_ids): + async def bulk_get_publicised_groups( + self, destination: str, user_ids: Iterable[str] + ) -> JsonDict: """Get the groups a list of users are publicising""" path = _create_v1_path("/get_groups_publicised") content = {"user_ids": user_ids} - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, data=content, ignore_backoff=True ) - def get_room_complexity(self, destination, room_id): + async def get_room_complexity(self, destination: str, room_id: str) -> JsonDict: """ Args: - destination (str): The remote server - room_id (str): The room ID to ask about. + destination: The remote server + room_id: The room ID to ask about. """ path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/rooms/%s/complexity", room_id) - return self.client.get_json(destination=destination, path=path) + return await self.client.get_json(destination=destination, path=path) async def get_space_summary( self, @@ -1075,14 +1178,14 @@ class TransportLayerClient: ) -def _create_path(federation_prefix, path, *args): +def _create_path(federation_prefix: str, path: str, *args: str) -> str: """ Ensures that all args are url encoded. """ return federation_prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args) -def _create_v1_path(path, *args): +def _create_v1_path(path: str, *args: str) -> str: """Creates a path against V1 federation API from the path template and args. Ensures that all args are url encoded. @@ -1091,16 +1194,13 @@ def _create_v1_path(path, *args): _create_v1_path("/event/%s", event_id) Args: - path (str): String template for the path - args: ([str]): Args to insert into path. Each arg will be url encoded - - Returns: - str + path: String template for the path + args: Args to insert into path. Each arg will be url encoded """ return _create_path(FEDERATION_V1_PREFIX, path, *args) -def _create_v2_path(path, *args): +def _create_v2_path(path: str, *args: str) -> str: """Creates a path against V2 federation API from the path template and args. Ensures that all args are url encoded. @@ -1109,11 +1209,8 @@ def _create_v2_path(path, *args): _create_v2_path("/event/%s", event_id) Args: - path (str): String template for the path - args: ([str]): Args to insert into path. Each arg will be url encoded - - Returns: - str + path: String template for the path + args: Args to insert into path. Each arg will be url encoded """ return _create_path(FEDERATION_V2_PREFIX, path, *args) -- cgit 1.4.1