summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-10-21 06:44:31 -0400
committerGitHub <noreply@github.com>2020-10-21 06:44:31 -0400
commitde5cafe980391ae6e2de1d38ac4e42dea182a304 (patch)
tree6b7a69562b72b7163363eea9811623bfc3ebef98
parentConsistently use wrap_as_background_task in more places (#8599) (diff)
downloadsynapse-de5cafe980391ae6e2de1d38ac4e42dea182a304.tar.xz
Add type hints to profile and base handlers. (#8609)
Diffstat (limited to '')
-rw-r--r--changelog.d/8609.misc1
-rw-r--r--mypy.ini4
-rw-r--r--synapse/handlers/_base.py20
-rw-r--r--synapse/handlers/initial_sync.py8
-rw-r--r--synapse/handlers/profile.py74
-rw-r--r--synapse/storage/databases/main/profile.py6
6 files changed, 72 insertions, 41 deletions
diff --git a/changelog.d/8609.misc b/changelog.d/8609.misc
new file mode 100644
index 0000000000..5e3f3c1993
--- /dev/null
+++ b/changelog.d/8609.misc
@@ -0,0 +1 @@
+Add type hints to profile and base handler.
diff --git a/mypy.ini b/mypy.ini
index b5db54ee3b..5e9f7b1259 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -15,8 +15,9 @@ files =
   synapse/events/builder.py,
   synapse/events/spamcheck.py,
   synapse/federation,
-  synapse/handlers/appservice.py,
+  synapse/handlers/_base.py,
   synapse/handlers/account_data.py,
+  synapse/handlers/appservice.py,
   synapse/handlers/auth.py,
   synapse/handlers/cas_handler.py,
   synapse/handlers/deactivate_account.py,
@@ -32,6 +33,7 @@ files =
   synapse/handlers/pagination.py,
   synapse/handlers/password_policy.py,
   synapse/handlers/presence.py,
+  synapse/handlers/profile.py,
   synapse/handlers/read_marker.py,
   synapse/handlers/room.py,
   synapse/handlers/room_member.py,
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 0206320e96..bd8e71ae56 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Optional
 
 import synapse.state
 import synapse.storage
@@ -22,6 +23,9 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.types import UserID
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -30,11 +34,7 @@ class BaseHandler:
     Common base class for the event handlers.
     """
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer):
-        """
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()  # type: synapse.storage.DataStore
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
@@ -56,7 +56,7 @@ class BaseHandler:
                 clock=self.clock,
                 rate_hz=self.hs.config.rc_admin_redaction.per_second,
                 burst_count=self.hs.config.rc_admin_redaction.burst_count,
-            )
+            )  # type: Optional[Ratelimiter]
         else:
             self.admin_redaction_ratelimiter = None
 
@@ -127,15 +127,15 @@ class BaseHandler:
             if guest_access != "can_join":
                 if context:
                     current_state_ids = await context.get_current_state_ids()
-                    current_state = await self.store.get_events(
+                    current_state_dict = await self.store.get_events(
                         list(current_state_ids.values())
                     )
+                    current_state = list(current_state_dict.values())
                 else:
-                    current_state = await self.state_handler.get_current_state(
+                    current_state_map = await self.state_handler.get_current_state(
                         event.room_id
                     )
-
-                current_state = list(current_state.values())
+                    current_state = list(current_state_map.values())
 
                 logger.info("maybe_kick_guest_users %r", current_state)
                 await self.kick_guest_users(current_state)
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 98075f48d2..cb11754bf8 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -293,6 +293,10 @@ class InitialSyncHandler(BaseHandler):
                 user_id, room_id, pagin_config, membership, is_peeking
             )
         elif membership == Membership.LEAVE:
+            # The member_event_id will always be available if membership is set
+            # to leave.
+            assert member_event_id
+
             result = await self._room_initial_sync_parted(
                 user_id, room_id, pagin_config, membership, member_event_id, is_peeking
             )
@@ -315,7 +319,7 @@ class InitialSyncHandler(BaseHandler):
         user_id: str,
         room_id: str,
         pagin_config: PaginationConfig,
-        membership: Membership,
+        membership: str,
         member_event_id: str,
         is_peeking: bool,
     ) -> JsonDict:
@@ -367,7 +371,7 @@ class InitialSyncHandler(BaseHandler):
         user_id: str,
         room_id: str,
         pagin_config: PaginationConfig,
-        membership: Membership,
+        membership: str,
         is_peeking: bool,
     ) -> JsonDict:
         current_state = await self.state.get_current_state(room_id=room_id)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index b78a12ad01..92700b589c 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -12,9 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 import random
+from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import (
     AuthError,
@@ -25,10 +25,19 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.types import UserID, create_requester, get_domain_from_id
+from synapse.types import (
+    JsonDict,
+    Requester,
+    UserID,
+    create_requester,
+    get_domain_from_id,
+)
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 MAX_DISPLAYNAME_LEN = 256
@@ -45,7 +54,7 @@ class ProfileHandler(BaseHandler):
     PROFILE_UPDATE_MS = 60 * 1000
     PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.federation = hs.get_federation_client()
@@ -60,7 +69,7 @@ class ProfileHandler(BaseHandler):
                 self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
             )
 
-    async def get_profile(self, user_id):
+    async def get_profile(self, user_id: str) -> JsonDict:
         target_user = UserID.from_string(user_id)
 
         if self.hs.is_mine(target_user):
@@ -91,7 +100,7 @@ class ProfileHandler(BaseHandler):
             except HttpResponseException as e:
                 raise e.to_synapse_error()
 
-    async def get_profile_from_cache(self, user_id):
+    async def get_profile_from_cache(self, user_id: str) -> JsonDict:
         """Get the profile information from our local cache. If the user is
         ours then the profile information will always be corect. Otherwise,
         it may be out of date/missing.
@@ -115,7 +124,7 @@ class ProfileHandler(BaseHandler):
             profile = await self.store.get_from_remote_profile_cache(user_id)
             return profile or {}
 
-    async def get_displayname(self, target_user):
+    async def get_displayname(self, target_user: UserID) -> str:
         if self.hs.is_mine(target_user):
             try:
                 displayname = await self.store.get_profile_displayname(
@@ -143,15 +152,19 @@ class ProfileHandler(BaseHandler):
             return result["displayname"]
 
     async def set_displayname(
-        self, target_user, requester, new_displayname, by_admin=False
-    ):
+        self,
+        target_user: UserID,
+        requester: Requester,
+        new_displayname: str,
+        by_admin: bool = False,
+    ) -> None:
         """Set the displayname of a user
 
         Args:
-            target_user (UserID): the user whose displayname is to be changed.
-            requester (Requester): The user attempting to make this change.
-            new_displayname (str): The displayname to give this user.
-            by_admin (bool): Whether this change was made by an administrator.
+            target_user: the user whose displayname is to be changed.
+            requester: The user attempting to make this change.
+            new_displayname: The displayname to give this user.
+            by_admin: Whether this change was made by an administrator.
         """
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "User is not hosted on this homeserver")
@@ -176,8 +189,9 @@ class ProfileHandler(BaseHandler):
                 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
             )
 
+        displayname_to_set = new_displayname  # type: Optional[str]
         if new_displayname == "":
-            new_displayname = None
+            displayname_to_set = None
 
         # If the admin changes the display name of a user, the requesting user cannot send
         # the join event to update the displayname in the rooms.
@@ -185,7 +199,9 @@ class ProfileHandler(BaseHandler):
         if by_admin:
             requester = create_requester(target_user)
 
-        await self.store.set_profile_displayname(target_user.localpart, new_displayname)
+        await self.store.set_profile_displayname(
+            target_user.localpart, displayname_to_set
+        )
 
         if self.hs.config.user_directory_search_all_users:
             profile = await self.store.get_profileinfo(target_user.localpart)
@@ -195,7 +211,7 @@ class ProfileHandler(BaseHandler):
 
         await self._update_join_states(requester, target_user)
 
-    async def get_avatar_url(self, target_user):
+    async def get_avatar_url(self, target_user: UserID) -> str:
         if self.hs.is_mine(target_user):
             try:
                 avatar_url = await self.store.get_profile_avatar_url(
@@ -222,15 +238,19 @@ class ProfileHandler(BaseHandler):
             return result["avatar_url"]
 
     async def set_avatar_url(
-        self, target_user, requester, new_avatar_url, by_admin=False
+        self,
+        target_user: UserID,
+        requester: Requester,
+        new_avatar_url: str,
+        by_admin: bool = False,
     ):
         """Set a new avatar URL for a user.
 
         Args:
-            target_user (UserID): the user whose avatar URL is to be changed.
-            requester (Requester): The user attempting to make this change.
-            new_avatar_url (str): The avatar URL to give this user.
-            by_admin (bool): Whether this change was made by an administrator.
+            target_user: the user whose avatar URL is to be changed.
+            requester: The user attempting to make this change.
+            new_avatar_url: The avatar URL to give this user.
+            by_admin: Whether this change was made by an administrator.
         """
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "User is not hosted on this homeserver")
@@ -267,7 +287,7 @@ class ProfileHandler(BaseHandler):
 
         await self._update_join_states(requester, target_user)
 
-    async def on_profile_query(self, args):
+    async def on_profile_query(self, args: JsonDict) -> JsonDict:
         user = UserID.from_string(args["user_id"])
         if not self.hs.is_mine(user):
             raise SynapseError(400, "User is not hosted on this homeserver")
@@ -292,7 +312,9 @@ class ProfileHandler(BaseHandler):
 
         return response
 
-    async def _update_join_states(self, requester, target_user):
+    async def _update_join_states(
+        self, requester: Requester, target_user: UserID
+    ) -> None:
         if not self.hs.is_mine(target_user):
             return
 
@@ -323,15 +345,17 @@ class ProfileHandler(BaseHandler):
                     "Failed to update join event for room %s - %s", room_id, str(e)
                 )
 
-    async def check_profile_query_allowed(self, target_user, requester=None):
+    async def check_profile_query_allowed(
+        self, target_user: UserID, requester: Optional[UserID] = None
+    ) -> None:
         """Checks whether a profile query is allowed. If the
         'require_auth_for_profile_requests' config flag is set to True and a
         'requester' is provided, the query is only allowed if the two users
         share a room.
 
         Args:
-            target_user (UserID): The owner of the queried profile.
-            requester (None|UserID): The user querying for the profile.
+            target_user: The owner of the queried profile.
+            requester: The user querying for the profile.
 
         Raises:
             SynapseError(403): The two users share no room, or ne user couldn't
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 1681caa1f0..a6d1eb908a 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
@@ -72,7 +72,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     async def set_profile_displayname(
-        self, user_localpart: str, new_displayname: str
+        self, user_localpart: str, new_displayname: Optional[str]
     ) -> None:
         await self.db_pool.simple_update_one(
             table="profiles",
@@ -144,7 +144,7 @@ class ProfileWorkerStore(SQLBaseStore):
 
     async def get_remote_profile_cache_entries_that_expire(
         self, last_checked: int
-    ) -> Dict[str, str]:
+    ) -> List[Dict[str, str]]:
         """Get all users who haven't been checked since `last_checked`
         """