diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 6ea1b50a62..df8cc4ac7a 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -15,11 +15,14 @@
import logging
from typing import TYPE_CHECKING
+from twisted.web.server import Request
+
from synapse.api.constants import LoginType
-from synapse.api.errors import SynapseError
+from synapse.api.errors import LoginError, SynapseError
from synapse.api.urls import CLIENT_API_PREFIX
-from synapse.http.server import respond_with_html
+from synapse.http.server import HttpServer, respond_with_html
from synapse.http.servlet import RestServlet, parse_string
+from synapse.http.site import SynapseRequest
from ._base import client_patterns
@@ -46,9 +49,10 @@ class AuthRestServlet(RestServlet):
self.registration_handler = hs.get_registration_handler()
self.recaptcha_template = hs.config.recaptcha_template
self.terms_template = hs.config.terms_template
+ self.registration_token_template = hs.config.registration_token_template
self.success_template = hs.config.fallback_success_template
- async def on_GET(self, request, stagetype):
+ async def on_GET(self, request: SynapseRequest, stagetype: str) -> None:
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
@@ -74,6 +78,12 @@ class AuthRestServlet(RestServlet):
# re-authenticate with their SSO provider.
html = await self.auth_handler.start_sso_ui_auth(request, session)
+ elif stagetype == LoginType.REGISTRATION_TOKEN:
+ html = self.registration_token_template.render(
+ session=session,
+ myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+ )
+
else:
raise SynapseError(404, "Unknown auth stage type")
@@ -81,7 +91,7 @@ class AuthRestServlet(RestServlet):
respond_with_html(request, 200, html)
return None
- async def on_POST(self, request, stagetype):
+ async def on_POST(self, request: Request, stagetype: str) -> None:
session = parse_string(request, "session")
if not session:
@@ -95,29 +105,32 @@ class AuthRestServlet(RestServlet):
authdict = {"response": response, "session": session}
- success = await self.auth_handler.add_oob_auth(
- LoginType.RECAPTCHA, authdict, request.getClientIP()
- )
-
- if success:
- html = self.success_template.render()
- else:
+ try:
+ await self.auth_handler.add_oob_auth(
+ LoginType.RECAPTCHA, authdict, request.getClientIP()
+ )
+ except LoginError as e:
+ # Authentication failed, let user try again
html = self.recaptcha_template.render(
session=session,
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
sitekey=self.hs.config.recaptcha_public_key,
+ error=e.msg,
)
+ else:
+ # No LoginError was raised, so authentication was successful
+ html = self.success_template.render()
+
elif stagetype == LoginType.TERMS:
authdict = {"session": session}
- success = await self.auth_handler.add_oob_auth(
- LoginType.TERMS, authdict, request.getClientIP()
- )
-
- if success:
- html = self.success_template.render()
- else:
+ try:
+ await self.auth_handler.add_oob_auth(
+ LoginType.TERMS, authdict, request.getClientIP()
+ )
+ except LoginError as e:
+ # Authentication failed, let user try again
html = self.terms_template.render(
session=session,
terms_url="%s_matrix/consent?v=%s"
@@ -127,10 +140,33 @@ class AuthRestServlet(RestServlet):
),
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
+ error=e.msg,
)
+ else:
+ # No LoginError was raised, so authentication was successful
+ html = self.success_template.render()
+
elif stagetype == LoginType.SSO:
# The SSO fallback workflow should not post here,
raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
+
+ elif stagetype == LoginType.REGISTRATION_TOKEN:
+ token = parse_string(request, "token", required=True)
+ authdict = {"session": session, "token": token}
+
+ try:
+ await self.auth_handler.add_oob_auth(
+ LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
+ )
+ except LoginError as e:
+ html = self.registration_token_template.render(
+ session=session,
+ myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+ error=e.msg,
+ )
+ else:
+ html = self.success_template.render()
+
else:
raise SynapseError(404, "Unknown auth stage type")
@@ -139,5 +175,5 @@ class AuthRestServlet(RestServlet):
return None
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AuthRestServlet(hs).register(http_server)
|