summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2016-07-19 13:12:22 +0100
committerRichard van der Hoff <richard@matrix.org>2016-07-19 13:12:22 +0100
commit0da0d0a29d807c481152b1580acbbe36f24cf771 (patch)
tree8a02a5eff71ad6410b726ee2d41233590b6b907d
parentMerge pull request #930 from matrix-org/markjh/handlers (diff)
downloadsynapse-0da0d0a29d807c481152b1580acbbe36f24cf771.tar.xz
rest/client/v2_alpha/register.py: Refactor flow somewhat.
This is meant to be an *almost* non-functional change, with the exception that
it fixes what looks a lot like a bug in that it only calls
`auth_handler.add_threepid` and `add_pusher` once instead of three times.

The idea is to move the generation of the `access_token` out of
`registration_handler.register`, because `access_token`s now require a
device_id, and we only want to generate a device_id once registration has been
successful.
-rw-r--r--synapse/rest/client/v2_alpha/register.py177
-rw-r--r--tests/rest/client/v2_alpha/test_register.py3
2 files changed, 104 insertions, 76 deletions
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index e8d34b06b0..707bde0f34 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -199,92 +199,55 @@ class RegisterRestServlet(RestServlet):
                 "Already registered user ID %r for this session",
                 registered_user_id
             )
-            access_token = yield self.auth_handler.issue_access_token(registered_user_id)
-            refresh_token = yield self.auth_handler.issue_refresh_token(
-                registered_user_id
+            # don't re-register the email address
+            add_email = False
+        else:
+            # NB: This may be from the auth handler and NOT from the POST
+            if 'password' not in params:
+                raise SynapseError(400, "Missing password.",
+                                   Codes.MISSING_PARAM)
+
+            desired_username = params.get("username", None)
+            new_password = params.get("password", None)
+            guest_access_token = params.get("guest_access_token", None)
+
+            (registered_user_id, _) = yield self.registration_handler.register(
+                localpart=desired_username,
+                password=new_password,
+                guest_access_token=guest_access_token,
+                generate_token=False,
             )
-            defer.returnValue((200, {
-                "user_id": registered_user_id,
-                "access_token": access_token,
-                "home_server": self.hs.hostname,
-                "refresh_token": refresh_token,
-            }))
 
-        # NB: This may be from the auth handler and NOT from the POST
-        if 'password' not in params:
-            raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
+            # remember that we've now registered that user account, and with
+            #  what user ID (since the user may not have specified)
+            self.auth_handler.set_session_data(
+                session_id, "registered_user_id", registered_user_id
+            )
 
-        desired_username = params.get("username", None)
-        new_password = params.get("password", None)
-        guest_access_token = params.get("guest_access_token", None)
+            add_email = True
 
-        (user_id, token) = yield self.registration_handler.register(
-            localpart=desired_username,
-            password=new_password,
-            guest_access_token=guest_access_token,
+        access_token = yield self.auth_handler.issue_access_token(
+            registered_user_id
         )
 
-        # remember that we've now registered that user account, and with what
-        # user ID (since the user may not have specified)
-        self.auth_handler.set_session_data(
-            session_id, "registered_user_id", user_id
-        )
-
-        if result and LoginType.EMAIL_IDENTITY in result:
+        if add_email and result and LoginType.EMAIL_IDENTITY in result:
             threepid = result[LoginType.EMAIL_IDENTITY]
-
-            for reqd in ['medium', 'address', 'validated_at']:
-                if reqd not in threepid:
-                    logger.info("Can't add incomplete 3pid")
-                else:
-                    yield self.auth_handler.add_threepid(
-                        user_id,
-                        threepid['medium'],
-                        threepid['address'],
-                        threepid['validated_at'],
-                    )
-
-                    # And we add an email pusher for them by default, but only
-                    # if email notifications are enabled (so people don't start
-                    # getting mail spam where they weren't before if email
-                    # notifs are set up on a home server)
-                    if (
-                        self.hs.config.email_enable_notifs and
-                        self.hs.config.email_notif_for_new_users
-                    ):
-                        # Pull the ID of the access token back out of the db
-                        # It would really make more sense for this to be passed
-                        # up when the access token is saved, but that's quite an
-                        # invasive change I'd rather do separately.
-                        user_tuple = yield self.store.get_user_by_access_token(
-                            token
-                        )
-
-                        yield self.hs.get_pusherpool().add_pusher(
-                            user_id=user_id,
-                            access_token=user_tuple["token_id"],
-                            kind="email",
-                            app_id="m.email",
-                            app_display_name="Email Notifications",
-                            device_display_name=threepid["address"],
-                            pushkey=threepid["address"],
-                            lang=None,  # We don't know a user's language here
-                            data={},
-                        )
-
-            if 'bind_email' in params and params['bind_email']:
+            reqd = ('medium', 'address', 'validated_at')
+            if all(x in threepid for x in reqd):
+                yield self._register_email_threepid(
+                    registered_user_id, threepid, access_token
+                )
+                # XXX why is bind_email not protected by this?
+            else:
+                logger.info("Can't add incomplete 3pid")
+            if params.get("bind_email"):
                 logger.info("bind_email specified: binding")
-
-                emailThreepid = result[LoginType.EMAIL_IDENTITY]
-                threepid_creds = emailThreepid['threepid_creds']
-                logger.debug("Binding emails %s to %s" % (
-                    emailThreepid, user_id
-                ))
-                yield self.identity_handler.bind_threepid(threepid_creds, user_id)
+                yield self._bind_email(registered_user_id, threepid)
             else:
                 logger.info("bind_email not specified: not binding email")
 
-        result = yield self._create_registration_details(user_id, token)
+        result = yield self._create_registration_details(registered_user_id,
+                                                         access_token)
         defer.returnValue((200, result))
 
     def on_OPTIONS(self, _):
@@ -325,6 +288,70 @@ class RegisterRestServlet(RestServlet):
         defer.returnValue((yield self._create_registration_details(user_id, token)))
 
     @defer.inlineCallbacks
+    def _register_email_threepid(self, user_id, threepid, token):
+        """Add an email address as a 3pid identifier
+
+        Also adds an email pusher for the email address, if configured in the
+        HS config
+
+        Args:
+            user_id (str): id of user
+            threepid (object): m.login.email.identity auth response
+            token (str): access_token for the user
+        Returns:
+            defer.Deferred:
+        """
+        yield self.auth_handler.add_threepid(
+            user_id,
+            threepid['medium'],
+            threepid['address'],
+            threepid['validated_at'],
+        )
+
+        # And we add an email pusher for them by default, but only
+        # if email notifications are enabled (so people don't start
+        # getting mail spam where they weren't before if email
+        # notifs are set up on a home server)
+        if (self.hs.config.email_enable_notifs and
+                self.hs.config.email_notif_for_new_users):
+            # Pull the ID of the access token back out of the db
+            # It would really make more sense for this to be passed
+            # up when the access token is saved, but that's quite an
+            # invasive change I'd rather do separately.
+            user_tuple = yield self.store.get_user_by_access_token(
+                token
+            )
+            token_id = user_tuple["token_id"]
+
+            yield self.hs.get_pusherpool().add_pusher(
+                user_id=user_id,
+                access_token=token_id,
+                kind="email",
+                app_id="m.email",
+                app_display_name="Email Notifications",
+                device_display_name=threepid["address"],
+                pushkey=threepid["address"],
+                lang=None,  # We don't know a user's language here
+                data={},
+            )
+        defer.returnValue()
+
+    def _bind_email(self, user_id, email_threepid):
+        """Bind emails to the given user_id on the identity server
+
+        Args:
+            user_id (str): user id to bind the emails to
+            email_threepid (object): m.login.email.identity auth response
+        Returns:
+            defer.Deferred:
+        """
+        threepid_creds = email_threepid['threepid_creds']
+        logger.debug("Binding emails %s to %s" % (
+            email_threepid, user_id
+        ))
+        return self.identity_handler.bind_threepid(threepid_creds, user_id)
+
+    @defer.inlineCallbacks
     def _create_registration_details(self, user_id, token):
         refresh_token = yield self.auth_handler.issue_refresh_token(user_id)
         defer.returnValue({
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index cda0a2b27c..9a4215fef7 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -114,7 +114,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
             "username": "kermit",
             "password": "monkey"
         }, None)
-        self.registration_handler.register = Mock(return_value=(user_id, token))
+        self.registration_handler.register = Mock(return_value=(user_id, None))
+        self.auth_handler.issue_access_token = Mock(return_value=token)
 
         (code, result) = yield self.servlet.on_POST(self.request)
         self.assertEquals(code, 200)