diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 7b5f49d635..a28acd4041 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -14,7 +14,9 @@
# limitations under the License.
import logging
import random
-from typing import List, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from twisted.web.server import Request
import synapse
import synapse.api.auth
@@ -29,15 +31,13 @@ from synapse.api.errors import (
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError
-from synapse.config.captcha import CaptchaConfig
-from synapse.config.consent import ConsentConfig
from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved
from synapse.handlers.auth import AuthHandler
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
-from synapse.http.server import finish_request, respond_with_html
+from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -45,6 +45,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
from synapse.types import JsonDict
@@ -59,17 +60,16 @@ from synapse.util.threepids import (
from ._base import client_patterns, interactive_auth_handler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class EmailRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/register/email/requestToken$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
@@ -83,7 +83,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
template_text=self.config.email_registration_template_text,
)
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
@@ -171,16 +171,12 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/register/msisdn/requestToken$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
assert_params_in_dict(
@@ -255,11 +251,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
"/registration/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
)
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -272,7 +264,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.config.email_registration_template_failure_html
)
- async def on_GET(self, request, medium):
+ async def on_GET(self, request: Request, medium: str) -> None:
if medium != "email":
raise SynapseError(
400, "This medium is currently not supported for registration"
@@ -326,11 +318,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
class UsernameAvailabilityRestServlet(RestServlet):
PATTERNS = client_patterns("/register/available")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.registration_handler = hs.get_registration_handler()
@@ -350,7 +338,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
),
)
- async def on_GET(self, request):
+ async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_registration:
raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
@@ -419,11 +407,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
class RegisterRestServlet(RestServlet):
PATTERNS = client_patterns("/register$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -445,23 +429,21 @@ class RegisterRestServlet(RestServlet):
)
@interactive_auth_handler
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
client_addr = request.getClientIP()
await self.ratelimiter.ratelimit(None, client_addr, update=False)
- kind = b"user"
- if b"kind" in request.args:
- kind = request.args[b"kind"][0]
+ kind = parse_string(request, "kind", default="user")
- if kind == b"guest":
+ if kind == "guest":
ret = await self._do_guest_registration(body, address=client_addr)
return ret
- elif kind != b"user":
+ elif kind != "user":
raise UnrecognizedRequestError(
- "Do not understand membership kind: %s" % (kind.decode("utf8"),)
+ f"Do not understand membership kind: {kind}",
)
if self._msc2918_enabled:
@@ -749,7 +731,7 @@ class RegisterRestServlet(RestServlet):
async def _do_appservice_registration(
self, username, as_token, body, should_issue_refresh_token: bool = False
- ):
+ ) -> JsonDict:
user_id = await self.registration_handler.appservice_register(
username, as_token
)
@@ -766,7 +748,7 @@ class RegisterRestServlet(RestServlet):
params: JsonDict,
is_appservice_ghost: bool = False,
should_issue_refresh_token: bool = False,
- ):
+ ) -> JsonDict:
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token.
@@ -810,7 +792,9 @@ class RegisterRestServlet(RestServlet):
return result
- async def _do_guest_registration(self, params, address=None):
+ async def _do_guest_registration(
+ self, params: JsonDict, address: Optional[str] = None
+ ) -> Tuple[int, JsonDict]:
if not self.hs.config.allow_guest_access:
raise SynapseError(403, "Guest access is disabled")
user_id = await self.registration_handler.register_user(
@@ -848,9 +832,7 @@ class RegisterRestServlet(RestServlet):
def _calculate_registration_flows(
- # technically `config` has to provide *all* of these interfaces, not just one
- config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
- auth_handler: AuthHandler,
+ config: HomeServerConfig, auth_handler: AuthHandler
) -> List[List[str]]:
"""Get a suitable flows list for registration
@@ -929,7 +911,7 @@ def _calculate_registration_flows(
return flows
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
UsernameAvailabilityRestServlet(hs).register(http_server)
|