diff --git a/changelog.d/8009.misc b/changelog.d/8009.misc
new file mode 100644
index 0000000000..3d58a11313
--- /dev/null
+++ b/changelog.d/8009.misc
@@ -0,0 +1 @@
+Improve the performance of the register endpoint.
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index fcfdf8bd2b..a3b0c0a3e4 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -239,14 +239,16 @@ class InteractiveAuthIncompleteError(Exception):
(This indicates we should return a 401 with 'result' as the body)
Attributes:
+ session_id: The ID of the ongoing interactive auth session.
result: the server response to the request, which should be
passed back to the client
"""
- def __init__(self, result: "JsonDict"):
+ def __init__(self, session_id: str, result: "JsonDict"):
super(InteractiveAuthIncompleteError, self).__init__(
"Interactive auth not yet complete"
)
+ self.session_id = session_id
self.result = result
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c7d921c21a..c24e7bafe0 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -162,7 +162,7 @@ class AuthHandler(BaseHandler):
request_body: Dict[str, Any],
clientip: str,
description: str,
- ) -> dict:
+ ) -> Tuple[dict, str]:
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -183,9 +183,14 @@ class AuthHandler(BaseHandler):
describes the operation happening on their account.
Returns:
- The parameters for this request (which may
+ A tuple of (params, session_id).
+
+ 'params' contains the parameters for this request (which may
have been given only in a previous call).
+ 'session_id' is the ID of this session, either passed in by the
+ client or assigned by this call
+
Raises:
InteractiveAuthIncompleteError if the client has not yet completed
any of the permitted login flows
@@ -207,7 +212,7 @@ class AuthHandler(BaseHandler):
flows = [[login_type] for login_type in self._supported_ui_auth_types]
try:
- result, params, _ = await self.check_auth(
+ result, params, session_id = await self.check_ui_auth(
flows, request, request_body, clientip, description
)
except LoginError:
@@ -230,7 +235,7 @@ class AuthHandler(BaseHandler):
if user_id != requester.user.to_string():
raise AuthError(403, "Invalid auth")
- return params
+ return params, session_id
def get_enabled_auth_types(self):
"""Return the enabled user-interactive authentication types
@@ -240,7 +245,7 @@ class AuthHandler(BaseHandler):
"""
return self.checkers.keys()
- async def check_auth(
+ async def check_ui_auth(
self,
flows: List[List[str]],
request: SynapseRequest,
@@ -363,7 +368,7 @@ class AuthHandler(BaseHandler):
if not authdict:
raise InteractiveAuthIncompleteError(
- self._auth_dict_for_flows(flows, session.session_id)
+ session.session_id, self._auth_dict_for_flows(flows, session.session_id)
)
# check auth type currently being presented
@@ -410,7 +415,7 @@ class AuthHandler(BaseHandler):
ret = self._auth_dict_for_flows(flows, session.session_id)
ret["completed"] = list(creds)
ret.update(errordict)
- raise InteractiveAuthIncompleteError(ret)
+ raise InteractiveAuthIncompleteError(session.session_id, ret)
async def add_oob_auth(
self, stagetype: str, authdict: Dict[str, Any], clientip: str
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index d4b1ee1e8c..c8e793ade2 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -26,7 +26,12 @@ if TYPE_CHECKING:
from twisted.internet import defer
from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, SynapseError, ThreepidValidationError
+from synapse.api.errors import (
+ Codes,
+ InteractiveAuthIncompleteError,
+ SynapseError,
+ ThreepidValidationError,
+)
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
@@ -253,18 +258,12 @@ class PasswordRestServlet(RestServlet):
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the new password provided to us.
- if "new_password" in body:
- new_password = body.pop("new_password")
+ new_password = body.pop("new_password", None)
+ if new_password is not None:
if not isinstance(new_password, str) or len(new_password) > 512:
raise SynapseError(400, "Invalid password")
self.password_policy_handler.validate_password(new_password)
- # If the password is valid, hash it and store it back on the body.
- # This ensures that only the hashed password is handled everywhere.
- if "new_password_hash" in body:
- raise SynapseError(400, "Unexpected property: new_password_hash")
- body["new_password_hash"] = await self.auth_handler.hash(new_password)
-
# there are two possibilities here. Either the user does not have an
# access token, and needs to do a password reset; or they have one and
# need to validate their identity.
@@ -281,23 +280,52 @@ class PasswordRestServlet(RestServlet):
if requester.app_service:
params = body
else:
- params = await self.auth_handler.validate_user_via_ui_auth(
- requester,
+ try:
+ (
+ params,
+ session_id,
+ ) = await self.auth_handler.validate_user_via_ui_auth(
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "modify your account password",
+ )
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth, but
+ # they're not required to provide the password again.
+ #
+ # If a password is available now, hash the provided password and
+ # store it for later.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
+ user_id = requester.user.to_string()
+ else:
+ requester = None
+ try:
+ result, params, session_id = await self.auth_handler.check_ui_auth(
+ [[LoginType.EMAIL_IDENTITY]],
request,
body,
self.hs.get_ip_from_request(request),
"modify your account password",
)
- user_id = requester.user.to_string()
- else:
- requester = None
- result, params, _ = await self.auth_handler.check_auth(
- [[LoginType.EMAIL_IDENTITY]],
- request,
- body,
- self.hs.get_ip_from_request(request),
- "modify your account password",
- )
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth, but
+ # they're not required to provide the password again.
+ #
+ # If a password is available now, hash the provided password and
+ # store it for later.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
if LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
@@ -322,12 +350,21 @@ class PasswordRestServlet(RestServlet):
logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
- assert_params_in_dict(params, ["new_password_hash"])
- new_password_hash = params["new_password_hash"]
+ # If we have a password in this request, prefer it. Otherwise, there
+ # must be a password hash from an earlier request.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ else:
+ password_hash = await self.auth_handler.get_session_data(
+ session_id, "password_hash", None
+ )
+ if not password_hash:
+ raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
+
logout_devices = params.get("logout_devices", True)
await self._set_password_handler.set_password(
- user_id, new_password_hash, logout_devices, requester
+ user_id, password_hash, logout_devices, requester
)
if self.hs.config.shadow_server:
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 001f49fb3e..a883278068 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -26,6 +26,7 @@ import synapse.types
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
+ InteractiveAuthIncompleteError,
SynapseError,
ThreepidValidationError,
UnrecognizedRequestError,
@@ -385,6 +386,7 @@ class RegisterRestServlet(RestServlet):
self.ratelimiter = hs.get_registration_ratelimiter()
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
+ self._registration_enabled = self.hs.config.enable_registration
self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler
@@ -410,22 +412,6 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind.decode("utf8"),)
)
- # we do basic sanity checks here because the auth layer will store these
- # in sessions. Pull out the username/password provided to us.
- desired_password_hash = None
- if "password" in body:
- password = body.pop("password")
- if not isinstance(password, str) or len(password) > 512:
- raise SynapseError(400, "Invalid password")
- self.password_policy_handler.validate_password(password)
-
- # If the password is valid, hash it and store it back on the body.
- # This ensures that only the hashed password is handled everywhere.
- if "password_hash" in body:
- raise SynapseError(400, "Unexpected property: password_hash")
- body["password_hash"] = await self.auth_handler.hash(password)
- desired_password_hash = body["password_hash"]
-
# We don't care about usernames for this deployment. In fact, the act
# of checking whether they exist already can leak metadata about
# which users are already registered.
@@ -442,6 +428,11 @@ class RegisterRestServlet(RestServlet):
if self.auth.has_access_token(request):
appservice = await self.auth.get_appservice_by_req(request)
+ # We need to retrieve the password early in order to pass it to
+ # application service registration
+ # This is specific to shadow server registration of users via an AS
+ password = body.pop("password", None)
+
# fork off as soon as possible for ASes which have completely
# different registration flows to normal users
@@ -461,21 +452,28 @@ class RegisterRestServlet(RestServlet):
if isinstance(desired_username, str):
result = await self._do_appservice_registration(
- desired_username,
- desired_password_hash,
- desired_display_name,
- access_token,
- body,
+ desired_username, password, desired_display_name, access_token, body
)
return 200, result # we throw for non 200 responses
# == Normal User Registration == (everyone else)
- if not self.hs.config.enable_registration:
+ if not self._registration_enabled:
raise SynapseError(403, "Registration has been disabled")
+ # Check if this account is upgrading from a guest account.
guest_access_token = body.get("guest_access_token", None)
- if "initial_device_display_name" in body and "password_hash" not in body:
+ # Pull out the provided password and do basic sanity checks early.
+ #
+ # Note that we remove the password from the body since the auth layer
+ # will store the body in the session and we don't want a plaintext
+ # password store there.
+ if password is not None:
+ if not isinstance(password, str) or len(password) > 512:
+ raise SynapseError(400, "Invalid password")
+ self.password_policy_handler.validate_password(password)
+
+ if "initial_device_display_name" in body and password is None:
# ignore 'initial_device_display_name' if sent without
# a password to work around a client bug where it sent
# the 'initial_device_display_name' param alone, wiping out
@@ -485,6 +483,7 @@ class RegisterRestServlet(RestServlet):
session_id = self.auth_handler.get_session_id(body)
registered_user_id = None
+ password_hash = None
if session_id:
# if we get a registered user id out of here, it means we previously
# registered a user for this session, so we could just return the
@@ -493,21 +492,43 @@ class RegisterRestServlet(RestServlet):
registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
+ # Extract the previously-hashed password from the session.
+ password_hash = await self.auth_handler.get_session_data(
+ session_id, "password_hash", None
+ )
- auth_result, params, session_id = await self.auth_handler.check_auth(
- self._registration_flows,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "register a new account",
- )
+ # Check if the user-interactive authentication flows are complete, if
+ # not this will raise a user-interactive auth error.
+ try:
+ auth_result, params, session_id = await self.auth_handler.check_ui_auth(
+ self._registration_flows,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "register a new account",
+ )
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth.
+ #
+ # Hash the password and store it with the session since the client
+ # is not required to provide the password again.
+ #
+ # If a password hash was previously stored we will not attempt to
+ # re-hash and store it for efficiency. This assumes the password
+ # does not change throughout the authentication flow, but this
+ # should be fine since the data is meant to be consistent.
+ if not password_hash and password:
+ password_hash = await self.auth_handler.hash(password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
# Check that we're not trying to register a denied 3pid.
#
# the user-facing checks will probably already have happened in
# /register/email/requestToken when we requested a 3pid, but that's not
# guaranteed.
-
if auth_result:
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
if login_type in auth_result:
@@ -603,8 +624,12 @@ class RegisterRestServlet(RestServlet):
# don't re-register the threepids
registered = False
else:
- # NB: This may be from the auth handler and NOT from the POST
- assert_params_in_dict(params, ["password_hash"])
+ # If we have a password in this request, prefer it. Otherwise, there
+ # might be a password hash from an earlier request.
+ if password:
+ password_hash = await self.auth_handler.hash(password)
+ if not password_hash:
+ raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
if not self.hs.config.register_mxid_from_3pid:
desired_username = params.get("username", None)
@@ -613,7 +638,6 @@ class RegisterRestServlet(RestServlet):
pass
guest_access_token = params.get("guest_access_token", None)
- new_password_hash = params.get("password_hash", None)
if desired_username is not None:
desired_username = desired_username.lower()
@@ -655,7 +679,7 @@ class RegisterRestServlet(RestServlet):
registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
- password_hash=new_password_hash,
+ password_hash=password_hash,
guest_access_token=guest_access_token,
default_display_name=desired_display_name,
threepid=threepid,
@@ -677,8 +701,8 @@ class RegisterRestServlet(RestServlet):
params=params,
)
- # remember that we've now registered that user account, and with
- # what user ID (since the user may not have specified)
+ # Remember that the user account has been registered (and the user
+ # ID it was registered with, since it might not have been specified).
await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
)
@@ -702,12 +726,20 @@ class RegisterRestServlet(RestServlet):
return 200, {}
async def _do_appservice_registration(
- self, username, password_hash, display_name, as_token, body
+ self, username, password, display_name, as_token, body
):
# FIXME: appservice_register() is horribly duplicated with register()
# and they should probably just be combined together with a config flag.
+
+ if password:
+ # Hash the password
+ #
+ # In mainline hashing of the password was moved further on in the registration
+ # flow, but we need it here for the AS use-case of shadow servers
+ password = await self.auth_handler.hash(password)
+
user_id = await self.registration_handler.appservice_register(
- username, as_token, password_hash, display_name
+ username, as_token, password, display_name
)
result = await self._create_registration_details(user_id, body)
diff --git a/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql
index e69de29bb2..96051ac179 100644
--- a/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql
+++ b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql
@@ -0,0 +1,23 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * A flag saying whether the user owning the profile has been deactivated
+ * This really belongs on the users table, not here, but the users table
+ * stores users by their full user_id and profiles stores them by localpart,
+ * so we can't easily join between the two tables. Plus, the batch number
+ * realy ought to represent data in this table that has changed.
+ */
+ALTER TABLE profiles ADD COLUMN active SMALLINT DEFAULT 1 NOT NULL;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql
index e69de29bb2..7542ab8cbd 100644
--- a/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql
+++ b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE UNIQUE INDEX profile_replication_status_idx ON profile_replication_status(host);
\ No newline at end of file
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index ceca4041e1..fce79b38f2 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -112,8 +112,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
+ @override_config({"enable_registration": False})
def test_POST_disabled_registration(self):
- self.hs.config.enable_registration = False
request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
|