summary refs log tree commit diff
path: root/synapse/rest/client/v2_alpha/register.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v2_alpha/register.py')
-rw-r--r--synapse/rest/client/v2_alpha/register.py59
1 files changed, 48 insertions, 11 deletions
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index ec5c21fa1f..d32c06c882 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -17,9 +17,9 @@ from twisted.internet import defer
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
-from synapse.http.servlet import RestServlet
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
 
-from ._base import client_v2_patterns, parse_json_dict_from_request
+from ._base import client_v2_patterns
 
 import logging
 import hmac
@@ -73,7 +73,7 @@ class RegisterRestServlet(RestServlet):
             ret = yield self.onEmailTokenRequest(request)
             defer.returnValue(ret)
 
-        body = parse_json_dict_from_request(request)
+        body = parse_json_object_from_request(request)
 
         # we do basic sanity checks here because the auth layer will store these
         # in sessions. Pull out the username/password provided to us.
@@ -122,10 +122,22 @@ class RegisterRestServlet(RestServlet):
 
         guest_access_token = body.get("guest_access_token", None)
 
+        session_id = self.auth_handler.get_session_id(body)
+        registered_user_id = None
+        if session_id:
+            # if we get a registered user id out of here, it means we previously
+            # registered a user for this session, so we could just return the
+            # user here. We carry on and go through the auth checks though,
+            # for paranoia.
+            registered_user_id = self.auth_handler.get_session_data(
+                session_id, "registered_user_id", None
+            )
+
         if desired_username is not None:
             yield self.registration_handler.check_username(
                 desired_username,
-                guest_access_token=guest_access_token
+                guest_access_token=guest_access_token,
+                assigned_user_id=registered_user_id,
             )
 
         if self.hs.config.enable_registration_captcha:
@@ -139,7 +151,7 @@ class RegisterRestServlet(RestServlet):
                 [LoginType.EMAIL_IDENTITY]
             ]
 
-        authed, result, params = yield self.auth_handler.check_auth(
+        authed, result, params, session_id = yield self.auth_handler.check_auth(
             flows, body, self.hs.get_ip_from_request(request)
         )
 
@@ -147,6 +159,22 @@ class RegisterRestServlet(RestServlet):
             defer.returnValue((401, result))
             return
 
+        if registered_user_id is not None:
+            logger.info(
+                "Already registered user ID %r for this session",
+                registered_user_id
+            )
+            access_token = yield self.auth_handler.issue_access_token(registered_user_id)
+            refresh_token = yield self.auth_handler.issue_refresh_token(
+                registered_user_id
+            )
+            defer.returnValue((200, {
+                "user_id": registered_user_id,
+                "access_token": access_token,
+                "home_server": self.hs.hostname,
+                "refresh_token": refresh_token,
+            }))
+
         # NB: This may be from the auth handler and NOT from the POST
         if 'password' not in params:
             raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
@@ -161,6 +189,12 @@ class RegisterRestServlet(RestServlet):
             guest_access_token=guest_access_token,
         )
 
+        # remember that we've now registered that user account, and with what
+        # user ID (since the user may not have specified)
+        self.auth_handler.set_session_data(
+            session_id, "registered_user_id", user_id
+        )
+
         if result and LoginType.EMAIL_IDENTITY in result:
             threepid = result[LoginType.EMAIL_IDENTITY]
 
@@ -187,7 +221,7 @@ class RegisterRestServlet(RestServlet):
             else:
                 logger.info("bind_email not specified: not binding email")
 
-        result = self._create_registration_details(user_id, token)
+        result = yield self._create_registration_details(user_id, token)
         defer.returnValue((200, result))
 
     def on_OPTIONS(self, _):
@@ -198,7 +232,7 @@ class RegisterRestServlet(RestServlet):
         (user_id, token) = yield self.registration_handler.appservice_register(
             username, as_token
         )
-        defer.returnValue(self._create_registration_details(user_id, token))
+        defer.returnValue((yield self._create_registration_details(user_id, token)))
 
     @defer.inlineCallbacks
     def _do_shared_secret_registration(self, username, password, mac):
@@ -225,18 +259,21 @@ class RegisterRestServlet(RestServlet):
         (user_id, token) = yield self.registration_handler.register(
             localpart=username, password=password
         )
-        defer.returnValue(self._create_registration_details(user_id, token))
+        defer.returnValue((yield self._create_registration_details(user_id, token)))
 
+    @defer.inlineCallbacks
     def _create_registration_details(self, user_id, token):
-        return {
+        refresh_token = yield self.auth_handler.issue_refresh_token(user_id)
+        defer.returnValue({
             "user_id": user_id,
             "access_token": token,
             "home_server": self.hs.hostname,
-        }
+            "refresh_token": refresh_token,
+        })
 
     @defer.inlineCallbacks
     def onEmailTokenRequest(self, request):
-        body = parse_json_dict_from_request(request)
+        body = parse_json_object_from_request(request)
 
         required = ['id_server', 'client_secret', 'email', 'send_attempt']
         absent = []