diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
index d3b6803e65..1bc737bad0 100644
--- a/synapse/rest/synapse/client/pick_username.py
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import TYPE_CHECKING
import pkg_resources
@@ -20,8 +21,7 @@ from twisted.web.http import Request
from twisted.web.resource import Resource
from twisted.web.static import File
-from synapse.api.errors import SynapseError
-from synapse.handlers.sso import USERNAME_MAPPING_SESSION_COOKIE_NAME
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
@@ -61,12 +61,10 @@ class AvailabilityCheckResource(DirectServeJsonResource):
async def _async_render_GET(self, request: Request):
localpart = parse_string(request, "username", required=True)
- session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
- if not session_id:
- raise SynapseError(code=400, msg="missing session_id")
+ session_id = get_username_mapping_session_cookie_from_request(request)
is_available = await self._sso_handler.check_username_availability(
- localpart, session_id.decode("ascii", errors="replace")
+ localpart, session_id
)
return 200, {"available": is_available}
@@ -79,10 +77,8 @@ class SubmitResource(DirectServeHtmlResource):
async def _async_render_POST(self, request: SynapseRequest):
localpart = parse_string(request, "username", required=True)
- session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
- if not session_id:
- raise SynapseError(code=400, msg="missing session_id")
+ session_id = get_username_mapping_session_cookie_from_request(request)
await self._sso_handler.handle_submit_username_request(
- request, localpart, session_id.decode("ascii", errors="replace")
+ request, localpart, session_id
)
diff --git a/synapse/rest/synapse/client/sso_register.py b/synapse/rest/synapse/client/sso_register.py
new file mode 100644
index 0000000000..dfefeb7796
--- /dev/null
+++ b/synapse/rest/synapse/client/sso_register.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import DirectServeHtmlResource
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class SsoRegisterResource(DirectServeHtmlResource):
+ """A resource which completes SSO registration
+
+ This resource gets mounted at /_synapse/client/sso_register, and is shown
+ after we collect username and/or consent for a new SSO user. It (finally) registers
+ the user, and confirms redirect to the client
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._sso_handler = hs.get_sso_handler()
+
+ async def _async_render_GET(self, request: Request) -> None:
+ try:
+ session_id = get_username_mapping_session_cookie_from_request(request)
+ except SynapseError as e:
+ logger.warning("Error fetching session cookie: %s", e)
+ self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+ await self._sso_handler.register_sso_user(request, session_id)
|