summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/directory.py59
-rw-r--r--synapse/handlers/identity.py9
-rw-r--r--synapse/handlers/message.py24
-rw-r--r--synapse/handlers/room_member.py2
-rw-r--r--synapse/handlers/ui_auth/checkers.py35
5 files changed, 76 insertions, 53 deletions
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 90932316f3..de1b14cde3 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -14,7 +14,7 @@
 
 import logging
 import string
-from typing import Iterable, List, Optional
+from typing import TYPE_CHECKING, Iterable, List, Optional
 
 from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes
 from synapse.api.errors import (
@@ -27,15 +27,19 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.appservice import ApplicationService
-from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id
+from synapse.storage.databases.main.directory import RoomAliasMapping
+from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class DirectoryHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.state = hs.get_state_handler()
@@ -60,7 +64,7 @@ class DirectoryHandler(BaseHandler):
         room_id: str,
         servers: Optional[Iterable[str]] = None,
         creator: Optional[str] = None,
-    ):
+    ) -> None:
         # general association creation for both human users and app services
 
         for wchar in string.whitespace:
@@ -104,8 +108,9 @@ class DirectoryHandler(BaseHandler):
         """
 
         user_id = requester.user.to_string()
+        room_alias_str = room_alias.to_string()
 
-        if len(room_alias.to_string()) > MAX_ALIAS_LENGTH:
+        if len(room_alias_str) > MAX_ALIAS_LENGTH:
             raise SynapseError(
                 400,
                 "Can't create aliases longer than %s characters" % MAX_ALIAS_LENGTH,
@@ -114,7 +119,7 @@ class DirectoryHandler(BaseHandler):
 
         service = requester.app_service
         if service:
-            if not service.is_interested_in_alias(room_alias.to_string()):
+            if not service.is_interested_in_alias(room_alias_str):
                 raise SynapseError(
                     400,
                     "This application service has not reserved this kind of alias.",
@@ -138,7 +143,7 @@ class DirectoryHandler(BaseHandler):
                 raise AuthError(403, "This user is not permitted to create this alias")
 
             if not self.config.is_alias_creation_allowed(
-                user_id, room_id, room_alias.to_string()
+                user_id, room_id, room_alias_str
             ):
                 # Lets just return a generic message, as there may be all sorts of
                 # reasons why we said no. TODO: Allow configurable error messages
@@ -211,7 +216,7 @@ class DirectoryHandler(BaseHandler):
 
     async def delete_appservice_association(
         self, service: ApplicationService, room_alias: RoomAlias
-    ):
+    ) -> None:
         if not service.is_interested_in_alias(room_alias.to_string()):
             raise SynapseError(
                 400,
@@ -220,7 +225,7 @@ class DirectoryHandler(BaseHandler):
             )
         await self._delete_association(room_alias)
 
-    async def _delete_association(self, room_alias: RoomAlias):
+    async def _delete_association(self, room_alias: RoomAlias) -> str:
         if not self.hs.is_mine(room_alias):
             raise SynapseError(400, "Room alias must be local")
 
@@ -228,17 +233,19 @@ class DirectoryHandler(BaseHandler):
 
         return room_id
 
-    async def get_association(self, room_alias: RoomAlias):
+    async def get_association(self, room_alias: RoomAlias) -> JsonDict:
         room_id = None
         if self.hs.is_mine(room_alias):
-            result = await self.get_association_from_room_alias(room_alias)
+            result = await self.get_association_from_room_alias(
+                room_alias
+            )  # type: Optional[RoomAliasMapping]
 
             if result:
                 room_id = result.room_id
                 servers = result.servers
         else:
             try:
-                result = await self.federation.make_query(
+                fed_result = await self.federation.make_query(
                     destination=room_alias.domain,
                     query_type="directory",
                     args={"room_alias": room_alias.to_string()},
@@ -248,13 +255,13 @@ class DirectoryHandler(BaseHandler):
             except CodeMessageException as e:
                 logging.warning("Error retrieving alias")
                 if e.code == 404:
-                    result = None
+                    fed_result = None
                 else:
                     raise
 
-            if result and "room_id" in result and "servers" in result:
-                room_id = result["room_id"]
-                servers = result["servers"]
+            if fed_result and "room_id" in fed_result and "servers" in fed_result:
+                room_id = fed_result["room_id"]
+                servers = fed_result["servers"]
 
         if not room_id:
             raise SynapseError(
@@ -275,7 +282,7 @@ class DirectoryHandler(BaseHandler):
 
         return {"room_id": room_id, "servers": servers}
 
-    async def on_directory_query(self, args):
+    async def on_directory_query(self, args: JsonDict) -> JsonDict:
         room_alias = RoomAlias.from_string(args["room_alias"])
         if not self.hs.is_mine(room_alias):
             raise SynapseError(400, "Room Alias is not hosted on this homeserver")
@@ -293,7 +300,7 @@ class DirectoryHandler(BaseHandler):
 
     async def _update_canonical_alias(
         self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
-    ):
+    ) -> None:
         """
         Send an updated canonical alias event if the removed alias was set as
         the canonical alias or listed in the alt_aliases field.
@@ -344,7 +351,9 @@ class DirectoryHandler(BaseHandler):
                 ratelimit=False,
             )
 
-    async def get_association_from_room_alias(self, room_alias: RoomAlias):
+    async def get_association_from_room_alias(
+        self, room_alias: RoomAlias
+    ) -> Optional[RoomAliasMapping]:
         result = await self.store.get_association_from_room_alias(room_alias)
         if not result:
             # Query AS to see if it exists
@@ -372,7 +381,7 @@ class DirectoryHandler(BaseHandler):
         # either no interested services, or no service with an exclusive lock
         return True
 
-    async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
+    async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str) -> bool:
         """Determine whether a user can delete an alias.
 
         One of the following must be true:
@@ -394,14 +403,13 @@ class DirectoryHandler(BaseHandler):
         if not room_id:
             return False
 
-        res = await self.auth.check_can_change_room_list(
+        return await self.auth.check_can_change_room_list(
             room_id, UserID.from_string(user_id)
         )
-        return res
 
     async def edit_published_room_list(
         self, requester: Requester, room_id: str, visibility: str
-    ):
+    ) -> None:
         """Edit the entry of the room in the published room list.
 
         requester
@@ -469,7 +477,7 @@ class DirectoryHandler(BaseHandler):
 
     async def edit_published_appservice_room_list(
         self, appservice_id: str, network_id: str, room_id: str, visibility: str
-    ):
+    ) -> None:
         """Add or remove a room from the appservice/network specific public
         room list.
 
@@ -499,5 +507,4 @@ class DirectoryHandler(BaseHandler):
                 room_id, requester.user.to_string()
             )
 
-        aliases = await self.store.get_aliases_for_room(room_id)
-        return aliases
+        return await self.store.get_aliases_for_room(room_id)
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 0b3b1fadb5..33d16fbf9c 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -17,7 +17,7 @@
 """Utilities for interacting with Identity Servers"""
 import logging
 import urllib.parse
-from typing import Awaitable, Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple
 
 from synapse.api.errors import (
     CodeMessageException,
@@ -41,13 +41,16 @@ from synapse.util.stringutils import (
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 id_server_scheme = "https://"
 
 
 class IdentityHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         # An HTTP client for contacting trusted URLs.
@@ -80,7 +83,7 @@ class IdentityHandler(BaseHandler):
         request: SynapseRequest,
         medium: str,
         address: str,
-    ):
+    ) -> None:
         """Used to ratelimit requests to `/requestToken` by IP and address.
 
         Args:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index ec8eb21674..49f8aa25ea 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 import logging
 import random
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 
@@ -66,7 +66,7 @@ logger = logging.getLogger(__name__)
 class MessageHandler:
     """Contains some read only APIs to get state about a room"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.state = hs.get_state_handler()
@@ -91,7 +91,7 @@ class MessageHandler:
         room_id: str,
         event_type: str,
         state_key: str,
-    ) -> dict:
+    ) -> Optional[EventBase]:
         """Get data from a room.
 
         Args:
@@ -115,6 +115,10 @@ class MessageHandler:
             data = await self.state.get_current_state(room_id, event_type, state_key)
         elif membership == Membership.LEAVE:
             key = (event_type, state_key)
+            # If the membership is not JOIN, then the event ID should exist.
+            assert (
+                membership_event_id is not None
+            ), "check_user_in_room_or_world_readable returned invalid data"
             room_state = await self.state_store.get_state_for_events(
                 [membership_event_id], StateFilter.from_types([key])
             )
@@ -186,10 +190,12 @@ class MessageHandler:
 
             event = last_events[0]
             if visible_events:
-                room_state = await self.state_store.get_state_for_events(
+                room_state_events = await self.state_store.get_state_for_events(
                     [event.event_id], state_filter=state_filter
                 )
-                room_state = room_state[event.event_id]
+                room_state = room_state_events[
+                    event.event_id
+                ]  # type: Mapping[Any, EventBase]
             else:
                 raise AuthError(
                     403,
@@ -210,10 +216,14 @@ class MessageHandler:
                 )
                 room_state = await self.store.get_events(state_ids.values())
             elif membership == Membership.LEAVE:
-                room_state = await self.state_store.get_state_for_events(
+                # If the membership is not JOIN, then the event ID should exist.
+                assert (
+                    membership_event_id is not None
+                ), "check_user_in_room_or_world_readable returned invalid data"
+                room_state_events = await self.state_store.get_state_for_events(
                     [membership_event_id], state_filter=state_filter
                 )
-                room_state = room_state[membership_event_id]
+                room_state = room_state_events[membership_event_id]
 
         now = self.clock.time_msec()
         events = await self._event_serializer.serialize_events(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 2c5bada1d8..20700fc5a8 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1044,7 +1044,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
 
 class RoomMemberMasterHandler(RoomMemberHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.distributor = hs.get_distributor()
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 0eeb7c03f2..5414ce77d8 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 from twisted.web.client import PartialDownloadError
 
@@ -22,13 +22,16 @@ from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.util import json_decoder
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class UserInteractiveAuthChecker:
     """Abstract base class for an interactive auth checker"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         pass
 
     def is_enabled(self) -> bool:
@@ -57,10 +60,10 @@ class UserInteractiveAuthChecker:
 class DummyAuthChecker(UserInteractiveAuthChecker):
     AUTH_TYPE = LoginType.DUMMY
 
-    def is_enabled(self):
+    def is_enabled(self) -> bool:
         return True
 
-    async def check_auth(self, authdict, clientip):
+    async def check_auth(self, authdict: dict, clientip: str) -> Any:
         return True
 
 
@@ -70,24 +73,24 @@ class TermsAuthChecker(UserInteractiveAuthChecker):
     def is_enabled(self):
         return True
 
-    async def check_auth(self, authdict, clientip):
+    async def check_auth(self, authdict: dict, clientip: str) -> Any:
         return True
 
 
 class RecaptchaAuthChecker(UserInteractiveAuthChecker):
     AUTH_TYPE = LoginType.RECAPTCHA
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self._enabled = bool(hs.config.recaptcha_private_key)
         self._http_client = hs.get_proxied_http_client()
         self._url = hs.config.recaptcha_siteverify_api
         self._secret = hs.config.recaptcha_private_key
 
-    def is_enabled(self):
+    def is_enabled(self) -> bool:
         return self._enabled
 
-    async def check_auth(self, authdict, clientip):
+    async def check_auth(self, authdict: dict, clientip: str) -> Any:
         try:
             user_response = authdict["response"]
         except KeyError:
@@ -132,11 +135,11 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
 
 
 class _BaseThreepidAuthChecker:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
 
-    async def _check_threepid(self, medium, authdict):
+    async def _check_threepid(self, medium: str, authdict: dict) -> dict:
         if "threepid_creds" not in authdict:
             raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
 
@@ -206,31 +209,31 @@ class _BaseThreepidAuthChecker:
 class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
     AUTH_TYPE = LoginType.EMAIL_IDENTITY
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         UserInteractiveAuthChecker.__init__(self, hs)
         _BaseThreepidAuthChecker.__init__(self, hs)
 
-    def is_enabled(self):
+    def is_enabled(self) -> bool:
         return self.hs.config.threepid_behaviour_email in (
             ThreepidBehaviour.REMOTE,
             ThreepidBehaviour.LOCAL,
         )
 
-    async def check_auth(self, authdict, clientip):
+    async def check_auth(self, authdict: dict, clientip: str) -> Any:
         return await self._check_threepid("email", authdict)
 
 
 class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
     AUTH_TYPE = LoginType.MSISDN
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         UserInteractiveAuthChecker.__init__(self, hs)
         _BaseThreepidAuthChecker.__init__(self, hs)
 
-    def is_enabled(self):
+    def is_enabled(self) -> bool:
         return bool(self.hs.config.account_threepid_delegate_msisdn)
 
-    async def check_auth(self, authdict, clientip):
+    async def check_auth(self, authdict: dict, clientip: str) -> Any:
         return await self._check_threepid("msisdn", authdict)