summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Baker <dbkr@users.noreply.github.com>2016-03-17 14:34:08 +0000
committerDavid Baker <dbkr@users.noreply.github.com>2016-03-17 14:34:08 +0000
commit384ee6eafba3f6b76652e2ffb336c66a16aa849d (patch)
tree48b2efc3a703a21e66f449c37e4aaef36567b9f6
parentMerge pull request #651 from matrix-org/markjh/unused_II (diff)
parentremove debug logging (diff)
downloadsynapse-384ee6eafba3f6b76652e2ffb336c66a16aa849d.tar.xz
Merge pull request #650 from matrix-org/dbkr/register_idempotent_with_username
Make registration idempotent, part 2
-rw-r--r--synapse/handlers/auth.py14
-rw-r--r--synapse/handlers/register.py12
-rw-r--r--synapse/rest/client/v2_alpha/register.py18
3 files changed, 38 insertions, 6 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index d7233cd0d6..82d458b424 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -160,6 +160,20 @@ class AuthHandler(BaseHandler):
             defer.returnValue(True)
         defer.returnValue(False)
 
+    def get_session_id(self, clientdict):
+        """
+        Gets the session ID for a client given the client dictionary
+        :param clientdict: The dictionary sent by the client in the request
+        :return: The string session ID the client sent. If the client did not
+                 send a session ID, returns None.
+        """
+        sid = None
+        if clientdict and 'auth' in clientdict:
+            authdict = clientdict['auth']
+            if 'session' in authdict:
+                sid = authdict['session']
+        return sid
+
     def set_session_data(self, session_id, key, value):
         """
         Store a key-value pair into the sessions data associated with this
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 6ffb8c0da6..f287ee247b 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -47,7 +47,8 @@ class RegistrationHandler(BaseHandler):
         self._next_generated_user_id = None
 
     @defer.inlineCallbacks
-    def check_username(self, localpart, guest_access_token=None):
+    def check_username(self, localpart, guest_access_token=None,
+                       assigned_user_id=None):
         yield run_on_reactor()
 
         if urllib.quote(localpart.encode('utf-8')) != localpart:
@@ -60,6 +61,15 @@ class RegistrationHandler(BaseHandler):
         user = UserID(localpart, self.hs.hostname)
         user_id = user.to_string()
 
+        if assigned_user_id:
+            if user_id == assigned_user_id:
+                return
+            else:
+                raise SynapseError(
+                    400,
+                    "A different user ID has already been registered for this session",
+                )
+
         yield self.check_user_id_not_appservice_exclusive(user_id)
 
         users = yield self.store.get_users_by_id_case_insensitive(user_id)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index c440430e25..d32c06c882 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -122,10 +122,22 @@ class RegisterRestServlet(RestServlet):
 
         guest_access_token = body.get("guest_access_token", None)
 
+        session_id = self.auth_handler.get_session_id(body)
+        registered_user_id = None
+        if session_id:
+            # if we get a registered user id out of here, it means we previously
+            # registered a user for this session, so we could just return the
+            # user here. We carry on and go through the auth checks though,
+            # for paranoia.
+            registered_user_id = self.auth_handler.get_session_data(
+                session_id, "registered_user_id", None
+            )
+
         if desired_username is not None:
             yield self.registration_handler.check_username(
                 desired_username,
-                guest_access_token=guest_access_token
+                guest_access_token=guest_access_token,
+                assigned_user_id=registered_user_id,
             )
 
         if self.hs.config.enable_registration_captcha:
@@ -147,10 +159,6 @@ class RegisterRestServlet(RestServlet):
             defer.returnValue((401, result))
             return
 
-        # have we already registered a user for this session
-        registered_user_id = self.auth_handler.get_session_data(
-            session_id, "registered_user_id", None
-        )
         if registered_user_id is not None:
             logger.info(
                 "Already registered user ID %r for this session",