summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-05-15 15:05:25 -0400
committerGitHub <noreply@github.com>2020-05-15 15:05:25 -0400
commitc29915bd05513a329e099d7e2970768113595830 (patch)
tree290b78204262a4f447d4e93274ccb6a8aa66e9ef
parentFix limit logic for AccountDataStream (#7384) (diff)
downloadsynapse-c29915bd05513a329e099d7e2970768113595830.tar.xz
Add type hints to room member handlers (#7513)
-rw-r--r--changelog.d/7513.misc1
-rw-r--r--synapse/handlers/room_member.py284
-rw-r--r--synapse/handlers/room_member_worker.py28
-rw-r--r--tox.ini2
4 files changed, 176 insertions, 139 deletions
diff --git a/changelog.d/7513.misc b/changelog.d/7513.misc
new file mode 100644
index 0000000000..2ea7373e29
--- /dev/null
+++ b/changelog.d/7513.misc
@@ -0,0 +1 @@
+Add type hints to room member handler.
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 4ddeba4c97..e51e1c32fe 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -17,13 +17,16 @@
 
 import abc
 import logging
+from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 from six.moves import http_client
 
 from synapse import types
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError, Codes, SynapseError
-from synapse.types import Collection, RoomID, UserID
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room, user_left_room
 
@@ -74,84 +77,84 @@ class RoomMemberHandler(object):
         self.base_handler = BaseHandler(hs)
 
     @abc.abstractmethod
-    async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+    async def _remote_join(
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        user: UserID,
+        content: dict,
+    ) -> Optional[dict]:
         """Try and join a room that this server is not in
 
         Args:
-            requester (Requester)
-            remote_room_hosts (list[str]): List of servers that can be used
-                to join via.
-            room_id (str): Room that we are trying to join
-            user (UserID): User who is trying to join
-            content (dict): A dict that should be used as the content of the
-                join event.
-
-        Returns:
-            Deferred
+            requester
+            remote_room_hosts: List of servers that can be used to join via.
+            room_id: Room that we are trying to join
+            user: User who is trying to join
+            content: A dict that should be used as the content of the join event.
         """
         raise NotImplementedError()
 
     @abc.abstractmethod
     async def _remote_reject_invite(
-        self, requester, remote_room_hosts, room_id, target, content
-    ):
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        target: UserID,
+        content: dict,
+    ) -> dict:
         """Attempt to reject an invite for a room this server is not in. If we
         fail to do so we locally mark the invite as rejected.
 
         Args:
-            requester (Requester)
-            remote_room_hosts (list[str]): List of servers to use to try and
-                reject invite
-            room_id (str)
-            target (UserID): The user rejecting the invite
-            content (dict): The content for the rejection event
+            requester
+            remote_room_hosts: List of servers to use to try and reject invite
+            room_id
+            target: The user rejecting the invite
+            content: The content for the rejection event
 
         Returns:
-            Deferred[dict]: A dictionary to be returned to the client, may
+            A dictionary to be returned to the client, may
             include event_id etc, or nothing if we locally rejected
         """
         raise NotImplementedError()
 
     @abc.abstractmethod
-    async def _user_joined_room(self, target, room_id):
+    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
         """Notifies distributor on master process that the user has joined the
         room.
 
         Args:
-            target (UserID)
-            room_id (str)
-
-        Returns:
-            None
+            target
+            room_id
         """
         raise NotImplementedError()
 
     @abc.abstractmethod
-    async def _user_left_room(self, target, room_id):
+    async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Notifies distributor on master process that the user has left the
         room.
 
         Args:
-            target (UserID)
-            room_id (str)
-
-        Returns:
-            None
+            target
+            room_id
         """
         raise NotImplementedError()
 
     async def _local_membership_update(
         self,
-        requester,
-        target,
-        room_id,
-        membership,
+        requester: Requester,
+        target: UserID,
+        room_id: str,
+        membership: str,
         prev_event_ids: Collection[str],
-        txn_id=None,
-        ratelimit=True,
-        content=None,
-        require_consent=True,
-    ):
+        txn_id: Optional[str] = None,
+        ratelimit: bool = True,
+        content: Optional[dict] = None,
+        require_consent: bool = True,
+    ) -> EventBase:
         user_id = target.to_string()
 
         if content is None:
@@ -214,16 +217,13 @@ class RoomMemberHandler(object):
 
     async def copy_room_tags_and_direct_to_room(
         self, old_room_id, new_room_id, user_id
-    ):
+    ) -> None:
         """Copies the tags and direct room state from one room to another.
 
         Args:
-            old_room_id (str)
-            new_room_id (str)
-            user_id (str)
-
-        Returns:
-            Deferred[None]
+            old_room_id: The room ID of the old room.
+            new_room_id: The room ID of the new room.
+            user_id: The user's ID.
         """
         # Retrieve user account data for predecessor room
         user_account_data, _ = await self.store.get_account_data_for_user(user_id)
@@ -253,17 +253,17 @@ class RoomMemberHandler(object):
 
     async def update_membership(
         self,
-        requester,
-        target,
-        room_id,
-        action,
-        txn_id=None,
-        remote_room_hosts=None,
-        third_party_signed=None,
-        ratelimit=True,
-        content=None,
-        require_consent=True,
-    ):
+        requester: Requester,
+        target: UserID,
+        room_id: str,
+        action: str,
+        txn_id: Optional[str] = None,
+        remote_room_hosts: Optional[List[str]] = None,
+        third_party_signed: Optional[dict] = None,
+        ratelimit: bool = True,
+        content: Optional[dict] = None,
+        require_consent: bool = True,
+    ) -> Union[EventBase, Optional[dict]]:
         key = (room_id,)
 
         with (await self.member_linearizer.queue(key)):
@@ -284,17 +284,17 @@ class RoomMemberHandler(object):
 
     async def _update_membership(
         self,
-        requester,
-        target,
-        room_id,
-        action,
-        txn_id=None,
-        remote_room_hosts=None,
-        third_party_signed=None,
-        ratelimit=True,
-        content=None,
-        require_consent=True,
-    ):
+        requester: Requester,
+        target: UserID,
+        room_id: str,
+        action: str,
+        txn_id: Optional[str] = None,
+        remote_room_hosts: Optional[List[str]] = None,
+        third_party_signed: Optional[dict] = None,
+        ratelimit: bool = True,
+        content: Optional[dict] = None,
+        require_consent: bool = True,
+    ) -> Union[EventBase, Optional[dict]]:
         content_specified = bool(content)
         if content is None:
             content = {}
@@ -468,12 +468,11 @@ class RoomMemberHandler(object):
                 else:
                     # send the rejection to the inviter's HS.
                     remote_room_hosts = remote_room_hosts + [inviter.domain]
-                    res = await self._remote_reject_invite(
+                    return await self._remote_reject_invite(
                         requester, remote_room_hosts, room_id, target, content,
                     )
-                    return res
 
-        res = await self._local_membership_update(
+        return await self._local_membership_update(
             requester=requester,
             target=target,
             room_id=room_id,
@@ -484,9 +483,10 @@ class RoomMemberHandler(object):
             content=content,
             require_consent=require_consent,
         )
-        return res
 
-    async def transfer_room_state_on_room_upgrade(self, old_room_id, room_id):
+    async def transfer_room_state_on_room_upgrade(
+        self, old_room_id: str, room_id: str
+    ) -> None:
         """Upon our server becoming aware of an upgraded room, either by upgrading a room
         ourselves or joining one, we can transfer over information from the previous room.
 
@@ -494,12 +494,8 @@ class RoomMemberHandler(object):
         well as migrating the room directory state.
 
         Args:
-            old_room_id (str): The ID of the old room
-
-            room_id (str): The ID of the new room
-
-        Returns:
-            Deferred
+            old_room_id: The ID of the old room
+            room_id: The ID of the new room
         """
         logger.info("Transferring room state from %s to %s", old_room_id, room_id)
 
@@ -526,17 +522,16 @@ class RoomMemberHandler(object):
             # Remove the old room from those groups
             await self.store.remove_room_from_group(group_id, old_room_id)
 
-    async def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids):
+    async def copy_user_state_on_room_upgrade(
+        self, old_room_id: str, new_room_id: str, user_ids: Iterable[str]
+    ) -> None:
         """Copy user-specific information when they join a new room when that new room is the
         result of a room upgrade
 
         Args:
-            old_room_id (str): The ID of upgraded room
-            new_room_id (str): The ID of the new room
-            user_ids (Iterable[str]): User IDs to copy state for
-
-        Returns:
-            Deferred
+            old_room_id: The ID of upgraded room
+            new_room_id: The ID of the new room
+            user_ids: User IDs to copy state for
         """
 
         logger.debug(
@@ -566,17 +561,23 @@ class RoomMemberHandler(object):
                 )
                 continue
 
-    async def send_membership_event(self, requester, event, context, ratelimit=True):
+    async def send_membership_event(
+        self,
+        requester: Requester,
+        event: EventBase,
+        context: EventContext,
+        ratelimit: bool = True,
+    ):
         """
         Change the membership status of a user in a room.
 
         Args:
-            requester (Requester): The local user who requested the membership
+            requester: The local user who requested the membership
                 event. If None, certain checks, like whether this homeserver can
                 act as the sender, will be skipped.
-            event (SynapseEvent): The membership event.
+            event: The membership event.
             context: The context of the event.
-            ratelimit (bool): Whether to rate limit this request.
+            ratelimit: Whether to rate limit this request.
         Raises:
             SynapseError if there was a problem changing the membership.
         """
@@ -636,7 +637,9 @@ class RoomMemberHandler(object):
                 if prev_member_event.membership == Membership.JOIN:
                     await self._user_left_room(target_user, room_id)
 
-    async def _can_guest_join(self, current_state_ids):
+    async def _can_guest_join(
+        self, current_state_ids: Dict[Tuple[str, str], str]
+    ) -> bool:
         """
         Returns whether a guest can join a room based on its current state.
         """
@@ -653,12 +656,14 @@ class RoomMemberHandler(object):
             and guest_access.content["guest_access"] == "can_join"
         )
 
-    async def lookup_room_alias(self, room_alias):
+    async def lookup_room_alias(
+        self, room_alias: RoomAlias
+    ) -> Tuple[RoomID, List[str]]:
         """
         Get the room ID associated with a room alias.
 
         Args:
-            room_alias (RoomAlias): The alias to look up.
+            room_alias: The alias to look up.
         Returns:
             A tuple of:
                 The room ID as a RoomID object.
@@ -682,24 +687,25 @@ class RoomMemberHandler(object):
 
         return RoomID.from_string(room_id), servers
 
-    async def _get_inviter(self, user_id, room_id):
+    async def _get_inviter(self, user_id: str, room_id: str) -> Optional[UserID]:
         invite = await self.store.get_invite_for_local_user_in_room(
             user_id=user_id, room_id=room_id
         )
         if invite:
             return UserID.from_string(invite.sender)
+        return None
 
     async def do_3pid_invite(
         self,
-        room_id,
-        inviter,
-        medium,
-        address,
-        id_server,
-        requester,
-        txn_id,
-        id_access_token=None,
-    ):
+        room_id: str,
+        inviter: UserID,
+        medium: str,
+        address: str,
+        id_server: str,
+        requester: Requester,
+        txn_id: Optional[str],
+        id_access_token: Optional[str] = None,
+    ) -> None:
         if self.config.block_non_admin_invites:
             is_requester_admin = await self.auth.is_server_admin(requester.user)
             if not is_requester_admin:
@@ -748,15 +754,15 @@ class RoomMemberHandler(object):
 
     async def _make_and_store_3pid_invite(
         self,
-        requester,
-        id_server,
-        medium,
-        address,
-        room_id,
-        user,
-        txn_id,
-        id_access_token=None,
-    ):
+        requester: Requester,
+        id_server: str,
+        medium: str,
+        address: str,
+        room_id: str,
+        user: UserID,
+        txn_id: Optional[str],
+        id_access_token: Optional[str] = None,
+    ) -> None:
         room_state = await self.state_handler.get_current_state(room_id)
 
         inviter_display_name = ""
@@ -830,7 +836,9 @@ class RoomMemberHandler(object):
             txn_id=txn_id,
         )
 
-    async def _is_host_in_room(self, current_state_ids):
+    async def _is_host_in_room(
+        self, current_state_ids: Dict[Tuple[str, str], str]
+    ) -> bool:
         # Have we just created the room, and is this about to be the very
         # first member event?
         create_event_id = current_state_ids.get(("m.room.create", ""))
@@ -852,7 +860,7 @@ class RoomMemberHandler(object):
 
         return False
 
-    async def _is_server_notice_room(self, room_id):
+    async def _is_server_notice_room(self, room_id: str) -> bool:
         if self._server_notices_mxid is None:
             return False
         user_ids = await self.store.get_users_in_room(room_id)
@@ -867,13 +875,15 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         self.distributor.declare("user_joined_room")
         self.distributor.declare("user_left_room")
 
-    async def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
+    async def _is_remote_room_too_complex(
+        self, room_id: str, remote_room_hosts: List[str]
+    ) -> Optional[bool]:
         """
         Check if complexity of a remote room is too great.
 
         Args:
-            room_id (str)
-            remote_room_hosts (list[str])
+            room_id
+            remote_room_hosts
 
         Returns: bool of whether the complexity is too great, or None
             if unable to be fetched
@@ -887,21 +897,26 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             return complexity["v1"] > max_complexity
         return None
 
-    async def _is_local_room_too_complex(self, room_id):
+    async def _is_local_room_too_complex(self, room_id: str) -> bool:
         """
         Check if the complexity of a local room is too great.
 
         Args:
-            room_id (str)
-
-        Returns: bool
+            room_id: The room ID to check for complexity.
         """
         max_complexity = self.hs.config.limit_remote_rooms.complexity
         complexity = await self.store.get_room_complexity(room_id)
 
         return complexity["v1"] > max_complexity
 
-    async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+    async def _remote_join(
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        user: UserID,
+        content: dict,
+    ) -> None:
         """Implements RoomMemberHandler._remote_join
         """
         # filter ourselves out of remote_room_hosts: do_invite_join ignores it
@@ -961,8 +976,13 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             )
 
     async def _remote_reject_invite(
-        self, requester, remote_room_hosts, room_id, target, content
-    ):
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        target: UserID,
+        content: dict,
+    ) -> dict:
         """Implements RoomMemberHandler._remote_reject_invite
         """
         fed_handler = self.federation_handler
@@ -983,17 +1003,17 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             await self.store.locally_reject_invite(target.to_string(), room_id)
             return {}
 
-    async def _user_joined_room(self, target, room_id):
+    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_joined_room
         """
-        return user_joined_room(self.distributor, target, room_id)
+        user_joined_room(self.distributor, target, room_id)
 
-    async def _user_left_room(self, target, room_id):
+    async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_left_room
         """
-        return user_left_room(self.distributor, target, room_id)
+        user_left_room(self.distributor, target, room_id)
 
-    async def forget(self, user, room_id):
+    async def forget(self, user: UserID, room_id: str) -> None:
         user_id = user.to_string()
 
         member = await self.state_handler.get_current_state(
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 0fc54349ab..5c776cc0be 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import List, Optional
 
 from synapse.api.errors import SynapseError
 from synapse.handlers.room_member import RoomMemberHandler
@@ -22,6 +23,7 @@ from synapse.replication.http.membership import (
     ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite,
     ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft,
 )
+from synapse.types import Requester, UserID
 
 logger = logging.getLogger(__name__)
 
@@ -34,7 +36,14 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         self._remote_reject_client = ReplRejectInvite.make_client(hs)
         self._notify_change_client = ReplJoinedLeft.make_client(hs)
 
-    async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+    async def _remote_join(
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        user: UserID,
+        content: dict,
+    ) -> Optional[dict]:
         """Implements RoomMemberHandler._remote_join
         """
         if len(remote_room_hosts) == 0:
@@ -53,8 +62,13 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         return ret
 
     async def _remote_reject_invite(
-        self, requester, remote_room_hosts, room_id, target, content
-    ):
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        target: UserID,
+        content: dict,
+    ) -> dict:
         """Implements RoomMemberHandler._remote_reject_invite
         """
         return await self._remote_reject_client(
@@ -65,16 +79,16 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
             content=content,
         )
 
-    async def _user_joined_room(self, target, room_id):
+    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_joined_room
         """
-        return await self._notify_change_client(
+        await self._notify_change_client(
             user_id=target.to_string(), room_id=room_id, change="joined"
         )
 
-    async def _user_left_room(self, target, room_id):
+    async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_left_room
         """
-        return await self._notify_change_client(
+        await self._notify_change_client(
             user_id=target.to_string(), room_id=room_id, change="left"
         )
diff --git a/tox.ini b/tox.ini
index a69bc04334..5a1fa610b6 100644
--- a/tox.ini
+++ b/tox.ini
@@ -188,6 +188,8 @@ commands = mypy \
             synapse/handlers/directory.py \
             synapse/handlers/oidc_handler.py \
             synapse/handlers/presence.py \
+            synapse/handlers/room_member.py \
+            synapse/handlers/room_member_worker.py \
             synapse/handlers/saml_handler.py \
             synapse/handlers/sync.py \
             synapse/handlers/ui_auth \