summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-06 08:09:55 -0400
committerGitHub <noreply@github.com>2020-08-06 08:09:55 -0400
commit66f24449dd614b23ea4c572d8d613efeb129e4a2 (patch)
tree73b60577aeee4054d5578e79af1cf13873355644
parentFixup worker doc (again) (#8000) (diff)
downloadsynapse-66f24449dd614b23ea4c572d8d613efeb129e4a2.tar.xz
Improve performance of the register endpoint (#8009)
Diffstat (limited to '')
-rw-r--r--changelog.d/8009.misc1
-rw-r--r--synapse/api/errors.py4
-rw-r--r--synapse/handlers/auth.py19
-rw-r--r--synapse/rest/client/v2_alpha/account.py86
-rw-r--r--synapse/rest/client/v2_alpha/register.py108
-rw-r--r--tests/rest/client/v2_alpha/test_register.py2
6 files changed, 146 insertions, 74 deletions
diff --git a/changelog.d/8009.misc b/changelog.d/8009.misc
new file mode 100644
index 0000000000..3d58a11313
--- /dev/null
+++ b/changelog.d/8009.misc
@@ -0,0 +1 @@
+Improve the performance of the register endpoint.
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index b3bab1aa52..6e40630ab6 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -238,14 +238,16 @@ class InteractiveAuthIncompleteError(Exception):
     (This indicates we should return a 401 with 'result' as the body)
 
     Attributes:
+        session_id: The ID of the ongoing interactive auth session.
         result: the server response to the request, which should be
             passed back to the client
     """
 
-    def __init__(self, result: "JsonDict"):
+    def __init__(self, session_id: str, result: "JsonDict"):
         super(InteractiveAuthIncompleteError, self).__init__(
             "Interactive auth not yet complete"
         )
+        self.session_id = session_id
         self.result = result
 
 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c7d921c21a..c24e7bafe0 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -162,7 +162,7 @@ class AuthHandler(BaseHandler):
         request_body: Dict[str, Any],
         clientip: str,
         description: str,
-    ) -> dict:
+    ) -> Tuple[dict, str]:
         """
         Checks that the user is who they claim to be, via a UI auth.
 
@@ -183,9 +183,14 @@ class AuthHandler(BaseHandler):
                          describes the operation happening on their account.
 
         Returns:
-            The parameters for this request (which may
+            A tuple of (params, session_id).
+
+                'params' contains the parameters for this request (which may
                 have been given only in a previous call).
 
+                'session_id' is the ID of this session, either passed in by the
+                client or assigned by this call
+
         Raises:
             InteractiveAuthIncompleteError if the client has not yet completed
                 any of the permitted login flows
@@ -207,7 +212,7 @@ class AuthHandler(BaseHandler):
         flows = [[login_type] for login_type in self._supported_ui_auth_types]
 
         try:
-            result, params, _ = await self.check_auth(
+            result, params, session_id = await self.check_ui_auth(
                 flows, request, request_body, clientip, description
             )
         except LoginError:
@@ -230,7 +235,7 @@ class AuthHandler(BaseHandler):
         if user_id != requester.user.to_string():
             raise AuthError(403, "Invalid auth")
 
-        return params
+        return params, session_id
 
     def get_enabled_auth_types(self):
         """Return the enabled user-interactive authentication types
@@ -240,7 +245,7 @@ class AuthHandler(BaseHandler):
         """
         return self.checkers.keys()
 
-    async def check_auth(
+    async def check_ui_auth(
         self,
         flows: List[List[str]],
         request: SynapseRequest,
@@ -363,7 +368,7 @@ class AuthHandler(BaseHandler):
 
         if not authdict:
             raise InteractiveAuthIncompleteError(
-                self._auth_dict_for_flows(flows, session.session_id)
+                session.session_id, self._auth_dict_for_flows(flows, session.session_id)
             )
 
         # check auth type currently being presented
@@ -410,7 +415,7 @@ class AuthHandler(BaseHandler):
         ret = self._auth_dict_for_flows(flows, session.session_id)
         ret["completed"] = list(creds)
         ret.update(errordict)
-        raise InteractiveAuthIncompleteError(ret)
+        raise InteractiveAuthIncompleteError(session.session_id, ret)
 
     async def add_oob_auth(
         self, stagetype: str, authdict: Dict[str, Any], clientip: str
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 3767a809a4..fead85074b 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -18,7 +18,12 @@ import logging
 from http import HTTPStatus
 
 from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, SynapseError, ThreepidValidationError
+from synapse.api.errors import (
+    Codes,
+    InteractiveAuthIncompleteError,
+    SynapseError,
+    ThreepidValidationError,
+)
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.servlet import (
@@ -239,18 +244,12 @@ class PasswordRestServlet(RestServlet):
 
         # we do basic sanity checks here because the auth layer will store these
         # in sessions. Pull out the new password provided to us.
-        if "new_password" in body:
-            new_password = body.pop("new_password")
+        new_password = body.pop("new_password", None)
+        if new_password is not None:
             if not isinstance(new_password, str) or len(new_password) > 512:
                 raise SynapseError(400, "Invalid password")
             self.password_policy_handler.validate_password(new_password)
 
-            # If the password is valid, hash it and store it back on the body.
-            # This ensures that only the hashed password is handled everywhere.
-            if "new_password_hash" in body:
-                raise SynapseError(400, "Unexpected property: new_password_hash")
-            body["new_password_hash"] = await self.auth_handler.hash(new_password)
-
         # there are two possibilities here. Either the user does not have an
         # access token, and needs to do a password reset; or they have one and
         # need to validate their identity.
@@ -263,23 +262,49 @@ class PasswordRestServlet(RestServlet):
 
         if self.auth.has_access_token(request):
             requester = await self.auth.get_user_by_req(request)
-            params = await self.auth_handler.validate_user_via_ui_auth(
-                requester,
-                request,
-                body,
-                self.hs.get_ip_from_request(request),
-                "modify your account password",
-            )
+            try:
+                params, session_id = await self.auth_handler.validate_user_via_ui_auth(
+                    requester,
+                    request,
+                    body,
+                    self.hs.get_ip_from_request(request),
+                    "modify your account password",
+                )
+            except InteractiveAuthIncompleteError as e:
+                # The user needs to provide more steps to complete auth, but
+                # they're not required to provide the password again.
+                #
+                # If a password is available now, hash the provided password and
+                # store it for later.
+                if new_password:
+                    password_hash = await self.auth_handler.hash(new_password)
+                    await self.auth_handler.set_session_data(
+                        e.session_id, "password_hash", password_hash
+                    )
+                raise
             user_id = requester.user.to_string()
         else:
             requester = None
-            result, params, _ = await self.auth_handler.check_auth(
-                [[LoginType.EMAIL_IDENTITY]],
-                request,
-                body,
-                self.hs.get_ip_from_request(request),
-                "modify your account password",
-            )
+            try:
+                result, params, session_id = await self.auth_handler.check_ui_auth(
+                    [[LoginType.EMAIL_IDENTITY]],
+                    request,
+                    body,
+                    self.hs.get_ip_from_request(request),
+                    "modify your account password",
+                )
+            except InteractiveAuthIncompleteError as e:
+                # The user needs to provide more steps to complete auth, but
+                # they're not required to provide the password again.
+                #
+                # If a password is available now, hash the provided password and
+                # store it for later.
+                if new_password:
+                    password_hash = await self.auth_handler.hash(new_password)
+                    await self.auth_handler.set_session_data(
+                        e.session_id, "password_hash", password_hash
+                    )
+                raise
 
             if LoginType.EMAIL_IDENTITY in result:
                 threepid = result[LoginType.EMAIL_IDENTITY]
@@ -304,12 +329,21 @@ class PasswordRestServlet(RestServlet):
                 logger.error("Auth succeeded but no known type! %r", result.keys())
                 raise SynapseError(500, "", Codes.UNKNOWN)
 
-        assert_params_in_dict(params, ["new_password_hash"])
-        new_password_hash = params["new_password_hash"]
+        # If we have a password in this request, prefer it. Otherwise, there
+        # must be a password hash from an earlier request.
+        if new_password:
+            password_hash = await self.auth_handler.hash(new_password)
+        else:
+            password_hash = await self.auth_handler.get_session_data(
+                session_id, "password_hash", None
+            )
+        if not password_hash:
+            raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
+
         logout_devices = params.get("logout_devices", True)
 
         await self._set_password_handler.set_password(
-            user_id, new_password_hash, logout_devices, requester
+            user_id, password_hash, logout_devices, requester
         )
 
         return 200, {}
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 370742ce59..a4c079196d 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -24,6 +24,7 @@ import synapse.types
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
     Codes,
+    InteractiveAuthIncompleteError,
     SynapseError,
     ThreepidValidationError,
     UnrecognizedRequestError,
@@ -387,6 +388,7 @@ class RegisterRestServlet(RestServlet):
         self.ratelimiter = hs.get_registration_ratelimiter()
         self.password_policy_handler = hs.get_password_policy_handler()
         self.clock = hs.get_clock()
+        self._registration_enabled = self.hs.config.enable_registration
 
         self._registration_flows = _calculate_registration_flows(
             hs.config, self.auth_handler
@@ -412,20 +414,8 @@ class RegisterRestServlet(RestServlet):
                 "Do not understand membership kind: %s" % (kind.decode("utf8"),)
             )
 
-        # we do basic sanity checks here because the auth layer will store these
-        # in sessions. Pull out the username/password provided to us.
-        if "password" in body:
-            password = body.pop("password")
-            if not isinstance(password, str) or len(password) > 512:
-                raise SynapseError(400, "Invalid password")
-            self.password_policy_handler.validate_password(password)
-
-            # If the password is valid, hash it and store it back on the body.
-            # This ensures that only the hashed password is handled everywhere.
-            if "password_hash" in body:
-                raise SynapseError(400, "Unexpected property: password_hash")
-            body["password_hash"] = await self.auth_handler.hash(password)
-
+        # Pull out the provided username and do basic sanity checks early since
+        # the auth layer will store these in sessions.
         desired_username = None
         if "username" in body:
             if not isinstance(body["username"], str) or len(body["username"]) > 512:
@@ -459,22 +449,35 @@ class RegisterRestServlet(RestServlet):
                 )
             return 200, result  # we throw for non 200 responses
 
-        # for regular registration, downcase the provided username before
-        # attempting to register it. This should mean
-        # that people who try to register with upper-case in their usernames
-        # don't get a nasty surprise. (Note that we treat username
-        # case-insenstively in login, so they are free to carry on imagining
-        # that their username is CrAzYh4cKeR if that keeps them happy)
-        if desired_username is not None:
-            desired_username = desired_username.lower()
-
         # == Normal User Registration == (everyone else)
-        if not self.hs.config.enable_registration:
+        if not self._registration_enabled:
             raise SynapseError(403, "Registration has been disabled")
 
+        # For regular registration, convert the provided username to lowercase
+        # before attempting to register it. This should mean that people who try
+        # to register with upper-case in their usernames don't get a nasty surprise.
+        #
+        # Note that we treat usernames case-insensitively in login, so they are
+        # free to carry on imagining that their username is CrAzYh4cKeR if that
+        # keeps them happy.
+        if desired_username is not None:
+            desired_username = desired_username.lower()
+
+        # Check if this account is upgrading from a guest account.
         guest_access_token = body.get("guest_access_token", None)
 
-        if "initial_device_display_name" in body and "password_hash" not in body:
+        # Pull out the provided password and do basic sanity checks early.
+        #
+        # Note that we remove the password from the body since the auth layer
+        # will store the body in the session and we don't want a plaintext
+        # password store there.
+        password = body.pop("password", None)
+        if password is not None:
+            if not isinstance(password, str) or len(password) > 512:
+                raise SynapseError(400, "Invalid password")
+            self.password_policy_handler.validate_password(password)
+
+        if "initial_device_display_name" in body and password is None:
             # ignore 'initial_device_display_name' if sent without
             # a password to work around a client bug where it sent
             # the 'initial_device_display_name' param alone, wiping out
@@ -484,6 +487,7 @@ class RegisterRestServlet(RestServlet):
 
         session_id = self.auth_handler.get_session_id(body)
         registered_user_id = None
+        password_hash = None
         if session_id:
             # if we get a registered user id out of here, it means we previously
             # registered a user for this session, so we could just return the
@@ -492,7 +496,12 @@ class RegisterRestServlet(RestServlet):
             registered_user_id = await self.auth_handler.get_session_data(
                 session_id, "registered_user_id", None
             )
+            # Extract the previously-hashed password from the session.
+            password_hash = await self.auth_handler.get_session_data(
+                session_id, "password_hash", None
+            )
 
+        # Ensure that the username is valid.
         if desired_username is not None:
             await self.registration_handler.check_username(
                 desired_username,
@@ -500,20 +509,38 @@ class RegisterRestServlet(RestServlet):
                 assigned_user_id=registered_user_id,
             )
 
-        auth_result, params, session_id = await self.auth_handler.check_auth(
-            self._registration_flows,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "register a new account",
-        )
+        # Check if the user-interactive authentication flows are complete, if
+        # not this will raise a user-interactive auth error.
+        try:
+            auth_result, params, session_id = await self.auth_handler.check_ui_auth(
+                self._registration_flows,
+                request,
+                body,
+                self.hs.get_ip_from_request(request),
+                "register a new account",
+            )
+        except InteractiveAuthIncompleteError as e:
+            # The user needs to provide more steps to complete auth.
+            #
+            # Hash the password and store it with the session since the client
+            # is not required to provide the password again.
+            #
+            # If a password hash was previously stored we will not attempt to
+            # re-hash and store it for efficiency. This assumes the password
+            # does not change throughout the authentication flow, but this
+            # should be fine since the data is meant to be consistent.
+            if not password_hash and password:
+                password_hash = await self.auth_handler.hash(password)
+                await self.auth_handler.set_session_data(
+                    e.session_id, "password_hash", password_hash
+                )
+            raise
 
         # Check that we're not trying to register a denied 3pid.
         #
         # the user-facing checks will probably already have happened in
         # /register/email/requestToken when we requested a 3pid, but that's not
         # guaranteed.
-
         if auth_result:
             for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
                 if login_type in auth_result:
@@ -535,12 +562,15 @@ class RegisterRestServlet(RestServlet):
             # don't re-register the threepids
             registered = False
         else:
-            # NB: This may be from the auth handler and NOT from the POST
-            assert_params_in_dict(params, ["password_hash"])
+            # If we have a password in this request, prefer it. Otherwise, there
+            # might be a password hash from an earlier request.
+            if password:
+                password_hash = await self.auth_handler.hash(password)
+            if not password_hash:
+                raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
 
             desired_username = params.get("username", None)
             guest_access_token = params.get("guest_access_token", None)
-            new_password_hash = params.get("password_hash", None)
 
             if desired_username is not None:
                 desired_username = desired_username.lower()
@@ -582,7 +612,7 @@ class RegisterRestServlet(RestServlet):
 
             registered_user_id = await self.registration_handler.register_user(
                 localpart=desired_username,
-                password_hash=new_password_hash,
+                password_hash=password_hash,
                 guest_access_token=guest_access_token,
                 threepid=threepid,
                 address=client_addr,
@@ -595,8 +625,8 @@ class RegisterRestServlet(RestServlet):
                 ):
                     await self.store.upsert_monthly_active_user(registered_user_id)
 
-            # remember that we've now registered that user account, and with
-            #  what user ID (since the user may not have specified)
+            # Remember that the user account has been registered (and the user
+            # ID it was registered with, since it might not have been specified).
             await self.auth_handler.set_session_data(
                 session_id, "registered_user_id", registered_user_id
             )
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 7deaf5b24a..53a43038f0 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -116,8 +116,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"200", channel.result)
         self.assertDictContainsSubset(det_data, channel.json_body)
 
+    @override_config({"enable_registration": False})
     def test_POST_disabled_registration(self):
-        self.hs.config.enable_registration = False
         request_data = json.dumps({"username": "kermit", "password": "monkey"})
         self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)