summary refs log tree commit diff
path: root/synapse/rest/client
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/rest/client/auth.py24
-rw-r--r--synapse/rest/client/register.py72
2 files changed, 96 insertions, 0 deletions
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 73284e48ec..91800c0278 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -46,6 +46,7 @@ class AuthRestServlet(RestServlet):
         self.registration_handler = hs.get_registration_handler()
         self.recaptcha_template = hs.config.recaptcha_template
         self.terms_template = hs.config.terms_template
+        self.registration_token_template = hs.config.registration_token_template
         self.success_template = hs.config.fallback_success_template
 
     async def on_GET(self, request, stagetype):
@@ -74,6 +75,12 @@ class AuthRestServlet(RestServlet):
             # re-authenticate with their SSO provider.
             html = await self.auth_handler.start_sso_ui_auth(request, session)
 
+        elif stagetype == LoginType.REGISTRATION_TOKEN:
+            html = self.registration_token_template.render(
+                session=session,
+                myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+            )
+
         else:
             raise SynapseError(404, "Unknown auth stage type")
 
@@ -140,6 +147,23 @@ class AuthRestServlet(RestServlet):
             # The SSO fallback workflow should not post here,
             raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
 
+        elif stagetype == LoginType.REGISTRATION_TOKEN:
+            token = parse_string(request, "token", required=True)
+            authdict = {"session": session, "token": token}
+
+            try:
+                await self.auth_handler.add_oob_auth(
+                    LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
+                )
+            except LoginError as e:
+                html = self.registration_token_template.render(
+                    session=session,
+                    myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+                    error=e.msg,
+                )
+            else:
+                html = self.success_template.render()
+
         else:
             raise SynapseError(404, "Unknown auth stage type")
 
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 58b8e8f261..2781a0ea96 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -28,6 +28,7 @@ from synapse.api.errors import (
     ThreepidValidationError,
     UnrecognizedRequestError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.config import ConfigError
 from synapse.config.captcha import CaptchaConfig
 from synapse.config.consent import ConsentConfig
@@ -379,6 +380,55 @@ class UsernameAvailabilityRestServlet(RestServlet):
             return 200, {"available": True}
 
 
+class RegistrationTokenValidityRestServlet(RestServlet):
+    """Check the validity of a registration token.
+
+    Example:
+
+        GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd
+
+        200 OK
+
+        {
+            "valid": true
+        }
+    """
+
+    PATTERNS = client_patterns(
+        f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity",
+        releases=(),
+        unstable=True,
+    )
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        super().__init__()
+        self.hs = hs
+        self.store = hs.get_datastore()
+        self.ratelimiter = Ratelimiter(
+            store=self.store,
+            clock=hs.get_clock(),
+            rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second,
+            burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
+        )
+
+    async def on_GET(self, request):
+        await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
+
+        if not self.hs.config.enable_registration:
+            raise SynapseError(
+                403, "Registration has been disabled", errcode=Codes.FORBIDDEN
+            )
+
+        token = parse_string(request, "token", required=True)
+        valid = await self.store.registration_token_is_valid(token)
+
+        return 200, {"valid": valid}
+
+
 class RegisterRestServlet(RestServlet):
     PATTERNS = client_patterns("/register$")
 
@@ -686,6 +736,22 @@ class RegisterRestServlet(RestServlet):
         )
 
         if registered:
+            # Check if a token was used to authenticate registration
+            registration_token = await self.auth_handler.get_session_data(
+                session_id,
+                UIAuthSessionDataConstants.REGISTRATION_TOKEN,
+            )
+            if registration_token:
+                # Increment the `completed` counter for the token
+                await self.store.use_registration_token(registration_token)
+                # Indicate that the token has been successfully used so that
+                # pending is not decremented again when expiring old UIA sessions.
+                await self.store.mark_ui_auth_stage_complete(
+                    session_id,
+                    LoginType.REGISTRATION_TOKEN,
+                    True,
+                )
+
             await self.registration_handler.post_registration_actions(
                 user_id=registered_user_id,
                 auth_result=auth_result,
@@ -868,6 +934,11 @@ def _calculate_registration_flows(
         for flow in flows:
             flow.insert(0, LoginType.RECAPTCHA)
 
+    # Prepend registration token to all flows if we're requiring a token
+    if config.registration_requires_token:
+        for flow in flows:
+            flow.insert(0, LoginType.REGISTRATION_TOKEN)
+
     return flows
 
 
@@ -876,4 +947,5 @@ def register_servlets(hs, http_server):
     MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
     UsernameAvailabilityRestServlet(hs).register(http_server)
     RegistrationSubmitTokenServlet(hs).register(http_server)
+    RegistrationTokenValidityRestServlet(hs).register(http_server)
     RegisterRestServlet(hs).register(http_server)