summary refs log tree commit diff
path: root/synapse/rest/client/account.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-09-01 11:59:32 -0400
committerGitHub <noreply@github.com>2021-09-01 11:59:32 -0400
commitd1f1b46c2cfe612790aebdd54765953951c94e45 (patch)
tree19b13d2d8b152ab7d3d05739c3124f2522aa85e1 /synapse/rest/client/account.py
parentPopulate `rooms.creator` field for easy lookup (#10697) (diff)
downloadsynapse-d1f1b46c2cfe612790aebdd54765953951c94e45.tar.xz
Additional type hints for client REST servlets (part 4) (#10728)
Diffstat (limited to 'synapse/rest/client/account.py')
-rw-r--r--synapse/rest/client/account.py82
1 files changed, 39 insertions, 43 deletions
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index fb5ad2906e..aefaaa8ae8 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -16,9 +16,11 @@
 import logging
 import random
 from http import HTTPStatus
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, Tuple
 from urllib.parse import urlparse
 
+from twisted.web.server import Request
+
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
     Codes,
@@ -28,15 +30,17 @@ from synapse.api.errors import (
 )
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
-from synapse.http.server import finish_request, respond_with_html
+from synapse.http.server import HttpServer, finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
     parse_json_object_from_request,
     parse_string,
 )
+from synapse.http.site import SynapseRequest
 from synapse.metrics import threepid_send_requests
 from synapse.push.mailer import Mailer
+from synapse.types import JsonDict
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.stringutils import assert_valid_client_secret, random_string
 from synapse.util.threepids import check_3pid_allowed, validate_email
@@ -68,7 +72,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
                 template_text=self.config.email_password_reset_template_text,
             )
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
             if self.config.local_threepid_handling_disabled_due_to_email_config:
                 logger.warning(
@@ -159,7 +163,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
 class PasswordRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/password$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
@@ -169,7 +173,7 @@ class PasswordRestServlet(RestServlet):
         self._set_password_handler = hs.get_set_password_handler()
 
     @interactive_auth_handler
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         body = parse_json_object_from_request(request)
 
         # we do basic sanity checks here because the auth layer will store these
@@ -190,6 +194,7 @@ class PasswordRestServlet(RestServlet):
         #
         # In the second case, we require a password to confirm their identity.
 
+        requester = None
         if self.auth.has_access_token(request):
             requester = await self.auth.get_user_by_req(request)
             try:
@@ -206,16 +211,15 @@ class PasswordRestServlet(RestServlet):
                 # If a password is available now, hash the provided password and
                 # store it for later.
                 if new_password:
-                    password_hash = await self.auth_handler.hash(new_password)
+                    new_password_hash = await self.auth_handler.hash(new_password)
                     await self.auth_handler.set_session_data(
                         e.session_id,
                         UIAuthSessionDataConstants.PASSWORD_HASH,
-                        password_hash,
+                        new_password_hash,
                     )
                 raise
             user_id = requester.user.to_string()
         else:
-            requester = None
             try:
                 result, params, session_id = await self.auth_handler.check_ui_auth(
                     [[LoginType.EMAIL_IDENTITY]],
@@ -230,11 +234,11 @@ class PasswordRestServlet(RestServlet):
                 # If a password is available now, hash the provided password and
                 # store it for later.
                 if new_password:
-                    password_hash = await self.auth_handler.hash(new_password)
+                    new_password_hash = await self.auth_handler.hash(new_password)
                     await self.auth_handler.set_session_data(
                         e.session_id,
                         UIAuthSessionDataConstants.PASSWORD_HASH,
-                        password_hash,
+                        new_password_hash,
                     )
                 raise
 
@@ -264,7 +268,7 @@ class PasswordRestServlet(RestServlet):
         # If we have a password in this request, prefer it. Otherwise, use the
         # password hash from an earlier request.
         if new_password:
-            password_hash = await self.auth_handler.hash(new_password)
+            password_hash: Optional[str] = await self.auth_handler.hash(new_password)
         elif session_id is not None:
             password_hash = await self.auth_handler.get_session_data(
                 session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
@@ -288,7 +292,7 @@ class PasswordRestServlet(RestServlet):
 class DeactivateAccountRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/deactivate$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
@@ -296,7 +300,7 @@ class DeactivateAccountRestServlet(RestServlet):
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
 
     @interactive_auth_handler
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         body = parse_json_object_from_request(request)
         erase = body.get("erase", False)
         if not isinstance(erase, bool):
@@ -338,7 +342,7 @@ class DeactivateAccountRestServlet(RestServlet):
 class EmailThreepidRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/email/requestToken$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.config = hs.config
@@ -353,7 +357,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                 template_text=self.config.email_add_threepid_template_text,
             )
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
             if self.config.local_threepid_handling_disabled_due_to_email_config:
                 logger.warning(
@@ -449,7 +453,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
         self.store = self.hs.get_datastore()
         self.identity_handler = hs.get_identity_handler()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         body = parse_json_object_from_request(request)
         assert_params_in_dict(
             body, ["client_secret", "country", "phone_number", "send_attempt"]
@@ -525,11 +529,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
         "/add_threepid/email/submit_token$", releases=(), unstable=True
     )
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.config = hs.config
         self.clock = hs.get_clock()
@@ -539,7 +539,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
                 self.config.email_add_threepid_template_failure_html
             )
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: Request) -> None:
         if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
             if self.config.local_threepid_handling_disabled_due_to_email_config:
                 logger.warning(
@@ -596,18 +596,14 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
         "/add_threepid/msisdn/submit_token$", releases=(), unstable=True
     )
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.config = hs.config
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.identity_handler = hs.get_identity_handler()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
         if not self.config.account_threepid_delegate_msisdn:
             raise SynapseError(
                 400,
@@ -632,7 +628,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
 class ThreepidRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_identity_handler()
@@ -640,14 +636,14 @@ class ThreepidRestServlet(RestServlet):
         self.auth_handler = hs.get_auth_handler()
         self.datastore = self.hs.get_datastore()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
         threepids = await self.datastore.user_get_threepids(requester.user.to_string())
 
         return 200, {"threepids": threepids}
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         if not self.hs.config.enable_3pid_changes:
             raise SynapseError(
                 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
@@ -688,7 +684,7 @@ class ThreepidRestServlet(RestServlet):
 class ThreepidAddRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/add$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_identity_handler()
@@ -696,7 +692,7 @@ class ThreepidAddRestServlet(RestServlet):
         self.auth_handler = hs.get_auth_handler()
 
     @interactive_auth_handler
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         if not self.hs.config.enable_3pid_changes:
             raise SynapseError(
                 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
@@ -738,13 +734,13 @@ class ThreepidAddRestServlet(RestServlet):
 class ThreepidBindRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/bind$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_identity_handler()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         body = parse_json_object_from_request(request)
 
         assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
@@ -767,14 +763,14 @@ class ThreepidBindRestServlet(RestServlet):
 class ThreepidUnbindRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/unbind$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_identity_handler()
         self.auth = hs.get_auth()
         self.datastore = self.hs.get_datastore()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         """Unbind the given 3pid from a specific identity server, or identity servers that are
         known to have this 3pid bound
         """
@@ -798,13 +794,13 @@ class ThreepidUnbindRestServlet(RestServlet):
 class ThreepidDeleteRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/delete$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         if not self.hs.config.enable_3pid_changes:
             raise SynapseError(
                 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
@@ -835,7 +831,7 @@ class ThreepidDeleteRestServlet(RestServlet):
         return 200, {"id_server_unbind_result": id_server_unbind_result}
 
 
-def assert_valid_next_link(hs: "HomeServer", next_link: str):
+def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None:
     """
     Raises a SynapseError if a given next_link value is invalid
 
@@ -877,11 +873,11 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str):
 class WhoamiRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/whoami$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
         response = {"user_id": requester.user.to_string()}
@@ -894,7 +890,7 @@ class WhoamiRestServlet(RestServlet):
         return 200, response
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     EmailPasswordRequestTokenRestServlet(hs).register(http_server)
     PasswordRestServlet(hs).register(http_server)
     DeactivateAccountRestServlet(hs).register(http_server)