From c2000ab35b76288a625f598d2382d4e3f29f65f6 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Wed, 4 Aug 2021 13:40:25 +0300 Subject: Add `get_userinfo_by_id` method to `ModuleApi` (#9581) Makes it easier to fetch user details in for example spam checker modules, without needing to use api._store or figure out database interactions. Signed-off-by: Jason Robinson --- synapse/storage/databases/main/registration.py | 30 +++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) (limited to 'synapse/storage/databases/main/registration.py') diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 6ad1a0cf7f..14670c2881 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -29,7 +29,7 @@ from synapse.storage.databases.main.stats import StatsStore from synapse.storage.types import Connection, Cursor from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import UserID +from synapse.types import UserID, UserInfo from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -146,6 +146,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): @cached() async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + """Deprecated: use get_userinfo_by_id instead""" return await self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, @@ -166,6 +167,33 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="get_user_by_id", ) + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: + """Get a UserInfo object for a user by user ID. + + Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed, + this method should be cached. + + Args: + user_id: The user to fetch user info for. + Returns: + `UserInfo` object if user found, otherwise `None`. + """ + user_data = await self.get_user_by_id(user_id) + if not user_data: + return None + return UserInfo( + appservice_id=user_data["appservice_id"], + consent_server_notice_sent=user_data["consent_server_notice_sent"], + consent_version=user_data["consent_version"], + creation_ts=user_data["creation_ts"], + is_admin=bool(user_data["admin"]), + is_deactivated=bool(user_data["deactivated"]), + is_guest=bool(user_data["is_guest"]), + is_shadow_banned=bool(user_data["shadow_banned"]), + user_id=UserID.from_string(user_data["name"]), + user_type=user_data["user_type"], + ) + async def is_trial_user(self, user_id: str) -> bool: """Checks if user is in the "trial" period, i.e. within the first N days of registration defined by `mau_trial_days` config -- cgit 1.5.1 From 3bcd525b46678ff228c4275acad47c12974c9a33 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 17 Aug 2021 12:56:11 +0200 Subject: Allow to edit `external_ids` by Edit User admin API (#10598) Signed-off-by: Dirk Klimpel dirk@klimpel.org --- changelog.d/10598.feature | 1 + docs/admin_api/user_admin_api.md | 40 +++-- synapse/rest/admin/users.py | 139 +++++++++------ synapse/storage/databases/main/registration.py | 22 +++ tests/rest/admin/test_user.py | 227 +++++++++++++++++++++---- 5 files changed, 340 insertions(+), 89 deletions(-) create mode 100644 changelog.d/10598.feature (limited to 'synapse/storage/databases/main/registration.py') diff --git a/changelog.d/10598.feature b/changelog.d/10598.feature new file mode 100644 index 0000000000..92c159118b --- /dev/null +++ b/changelog.d/10598.feature @@ -0,0 +1 @@ +Allow editing a user's `external_ids` via the "Edit User" admin API. Contributed by @dklimpel. \ No newline at end of file diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 4b5dd4685a..6a9335d6ec 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -81,6 +81,16 @@ with a body of: "address": "" } ], + "external_ids": [ + { + "auth_provider": "", + "external_id": "" + }, + { + "auth_provider": "", + "external_id": "" + } + ], "avatar_url": "", "admin": false, "deactivated": false @@ -90,26 +100,34 @@ with a body of: To use it, you will need to authenticate by providing an `access_token` for a server admin: [Admin API](../usage/administration/admin_api) +Returns HTTP status code: +- `201` - When a new user object was created. +- `200` - When a user was modified. + URL parameters: - `user_id`: fully-qualified user id: for example, `@user:server.com`. Body parameters: -- `password`, optional. If provided, the user's password is updated and all +- `password` - string, optional. If provided, the user's password is updated and all devices are logged out. - -- `displayname`, optional, defaults to the value of `user_id`. - -- `threepids`, optional, allows setting the third-party IDs (email, msisdn) +- `displayname` - string, optional, defaults to the value of `user_id`. +- `threepids` - array, optional, allows setting the third-party IDs (email, msisdn) + - `medium` - string. Kind of third-party ID, either `email` or `msisdn`. + - `address` - string. Value of third-party ID. belonging to a user. - -- `avatar_url`, optional, must be a +- `external_ids` - array, optional. Allow setting the identifier of the external identity + provider for SSO (Single sign-on). Details in + [Sample Configuration File](../usage/configuration/homeserver_sample_config.html) + section `sso` and `oidc_providers`. + - `auth_provider` - string. ID of the external identity provider. Value of `idp_id` + in homeserver configuration. + - `external_id` - string, user ID in the external identity provider. +- `avatar_url` - string, optional, must be a [MXC URI](https://matrix.org/docs/spec/client_server/r0.6.0#matrix-content-mxc-uris). - -- `admin`, optional, defaults to `false`. - -- `deactivated`, optional. If unspecified, deactivation state will be left +- `admin` - bool, optional, defaults to `false`. +- `deactivated` - bool, optional. If unspecified, deactivation state will be left unchanged on existing accounts and set to `false` for new accounts. A user cannot be erased by deactivating with this API. For details on deactivating users see [Deactivate Account](#deactivate-account). diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 41f21ba118..c885fd77ab 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -196,20 +196,57 @@ class UserRestServletV2(RestServlet): user = await self.admin_handler.get_user(target_user) user_id = target_user.to_string() + # check for required parameters for each threepid + threepids = body.get("threepids") + if threepids is not None: + for threepid in threepids: + assert_params_in_dict(threepid, ["medium", "address"]) + + # check for required parameters for each external_id + external_ids = body.get("external_ids") + if external_ids is not None: + for external_id in external_ids: + assert_params_in_dict(external_id, ["auth_provider", "external_id"]) + + user_type = body.get("user_type", None) + if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: + raise SynapseError(400, "Invalid user type") + + set_admin_to = body.get("admin", False) + if not isinstance(set_admin_to, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'admin' must be a boolean, if given", + Codes.BAD_JSON, + ) + + password = body.get("password", None) + if password is not None: + if not isinstance(password, str) or len(password) > 512: + raise SynapseError(400, "Invalid password") + + deactivate = body.get("deactivated", False) + if not isinstance(deactivate, bool): + raise SynapseError(400, "'deactivated' parameter is not of type boolean") + + # convert into List[Tuple[str, str]] + if external_ids is not None: + new_external_ids = [] + for external_id in external_ids: + new_external_ids.append( + (external_id["auth_provider"], external_id["external_id"]) + ) + if user: # modify user if "displayname" in body: await self.profile_handler.set_displayname( target_user, requester, body["displayname"], True ) - if "threepids" in body: - # check for required parameters for each threepid - for threepid in body["threepids"]: - assert_params_in_dict(threepid, ["medium", "address"]) - + if threepids is not None: # remove old threepids from user - threepids = await self.store.user_get_threepids(user_id) - for threepid in threepids: + old_threepids = await self.store.user_get_threepids(user_id) + for threepid in old_threepids: try: await self.auth_handler.delete_threepid( user_id, threepid["medium"], threepid["address"], None @@ -220,18 +257,39 @@ class UserRestServletV2(RestServlet): # add new threepids to user current_time = self.hs.get_clock().time_msec() - for threepid in body["threepids"]: + for threepid in threepids: await self.auth_handler.add_threepid( user_id, threepid["medium"], threepid["address"], current_time ) - if "avatar_url" in body and type(body["avatar_url"]) == str: + if external_ids is not None: + # get changed external_ids (added and removed) + cur_external_ids = await self.store.get_external_ids_by_user(user_id) + add_external_ids = set(new_external_ids) - set(cur_external_ids) + del_external_ids = set(cur_external_ids) - set(new_external_ids) + + # remove old external_ids + for auth_provider, external_id in del_external_ids: + await self.store.remove_user_external_id( + auth_provider, + external_id, + user_id, + ) + + # add new external_ids + for auth_provider, external_id in add_external_ids: + await self.store.record_user_external_id( + auth_provider, + external_id, + user_id, + ) + + if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( target_user, requester, body["avatar_url"], True ) if "admin" in body: - set_admin_to = bool(body["admin"]) if set_admin_to != user["admin"]: auth_user = requester.user if target_user == auth_user and not set_admin_to: @@ -239,29 +297,18 @@ class UserRestServletV2(RestServlet): await self.store.set_server_admin(target_user, set_admin_to) - if "password" in body: - if not isinstance(body["password"], str) or len(body["password"]) > 512: - raise SynapseError(400, "Invalid password") - else: - new_password = body["password"] - logout_devices = True - - new_password_hash = await self.auth_handler.hash(new_password) - - await self.set_password_handler.set_password( - target_user.to_string(), - new_password_hash, - logout_devices, - requester, - ) + if password is not None: + logout_devices = True + new_password_hash = await self.auth_handler.hash(password) + + await self.set_password_handler.set_password( + target_user.to_string(), + new_password_hash, + logout_devices, + requester, + ) if "deactivated" in body: - deactivate = body["deactivated"] - if not isinstance(deactivate, bool): - raise SynapseError( - 400, "'deactivated' parameter is not of type boolean" - ) - if deactivate and not user["deactivated"]: await self.deactivate_account_handler.deactivate_account( target_user.to_string(), False, requester, by_admin=True @@ -285,36 +332,24 @@ class UserRestServletV2(RestServlet): return 200, user else: # create user - password = body.get("password") + displayname = body.get("displayname", None) + password_hash = None if password is not None: - if not isinstance(password, str) or len(password) > 512: - raise SynapseError(400, "Invalid password") password_hash = await self.auth_handler.hash(password) - admin = body.get("admin", None) - user_type = body.get("user_type", None) - displayname = body.get("displayname", None) - - if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: - raise SynapseError(400, "Invalid user type") - user_id = await self.registration_handler.register_user( localpart=target_user.localpart, password_hash=password_hash, - admin=bool(admin), + admin=set_admin_to, default_display_name=displayname, user_type=user_type, by_admin=True, ) - if "threepids" in body: - # check for required parameters for each threepid - for threepid in body["threepids"]: - assert_params_in_dict(threepid, ["medium", "address"]) - + if threepids is not None: current_time = self.hs.get_clock().time_msec() - for threepid in body["threepids"]: + for threepid in threepids: await self.auth_handler.add_threepid( user_id, threepid["medium"], threepid["address"], current_time ) @@ -334,6 +369,14 @@ class UserRestServletV2(RestServlet): data={}, ) + if external_ids is not None: + for auth_provider, external_id in new_external_ids: + await self.store.record_user_external_id( + auth_provider, + external_id, + user_id, + ) + if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( target_user, requester, body["avatar_url"], True diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 14670c2881..c67bea81c6 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -599,6 +599,28 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="record_user_external_id", ) + async def remove_user_external_id( + self, auth_provider: str, external_id: str, user_id: str + ) -> None: + """Remove a mapping from an external user id to a mxid + + If the mapping is not found, this method does nothing. + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ + await self.db_pool.simple_delete( + table="user_external_ids", + keyvalues={ + "auth_provider": auth_provider, + "external_id": external_id, + "user_id": user_id, + }, + desc="remove_user_external_id", + ) + async def get_user_by_external_id( self, auth_provider: str, external_id: str ) -> Optional[str]: diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 13fab5579b..a736ec4754 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1240,56 +1240,114 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) - def test_get_user(self): + def test_invalid_parameter(self): """ - Test a simple get of a user. + If parameters are invalid, an error is returned. """ + + # admin not bool channel = self.make_request( - "GET", + "PUT", self.url_other_user, access_token=self.admin_user_tok, + content={"admin": "not_bool"}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual("User", channel.json_body["displayname"]) - self._check_fields(channel.json_body) + # deactivated not bool + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"deactivated": "not_bool"}, + ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - def test_get_user_with_sso(self): - """ - Test get a user with SSO details. - """ - self.get_success( - self.store.record_user_external_id( - "auth_provider1", "external_id1", self.other_user - ) + # password not str + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"password": True}, ) - self.get_success( - self.store.record_user_external_id( - "auth_provider2", "external_id2", self.other_user - ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # password not length + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"password": "x" * 513}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + # user_type not valid channel = self.make_request( - "GET", + "PUT", self.url_other_user, access_token=self.admin_user_tok, + content={"user_type": "new type"}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual( - "external_id1", channel.json_body["external_ids"][0]["external_id"] + # external_ids not valid + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": {"auth_provider": "prov", "wrong_external_id": "id"} + }, ) - self.assertEqual( - "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"] + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"external_ids": {"external_id": "id"}}, ) - self.assertEqual( - "external_id2", channel.json_body["external_ids"][1]["external_id"] + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + # threepids not valid + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"threepids": {"medium": "email", "wrong_address": "id"}}, ) - self.assertEqual( - "auth_provider2", channel.json_body["external_ids"][1]["auth_provider"] + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"threepids": {"address": "value"}}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + def test_get_user(self): + """ + Test a simple get of a user. + """ + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("User", channel.json_body["displayname"]) self._check_fields(channel.json_body) def test_create_server_admin(self): @@ -1353,6 +1411,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): "admin": False, "displayname": "Bob's name", "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + "external_ids": [ + { + "external_id": "external_id1", + "auth_provider": "auth_provider1", + }, + ], "avatar_url": "mxc://fibble/wibble", } @@ -1368,6 +1432,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual( + "external_id1", channel.json_body["external_ids"][0]["external_id"] + ) + self.assertEqual( + "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"] + ) self.assertFalse(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self._check_fields(channel.json_body) @@ -1632,6 +1702,103 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + def test_set_external_id(self): + """ + Test setting external id for an other user. + """ + + # Add two external_ids + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": [ + { + "external_id": "external_id1", + "auth_provider": "auth_provider1", + }, + { + "external_id": "external_id2", + "auth_provider": "auth_provider2", + }, + ] + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["external_ids"])) + # result does not always have the same sort order, therefore it becomes sorted + self.assertEqual( + sorted(channel.json_body["external_ids"], key=lambda k: k["auth_provider"]), + [ + {"auth_provider": "auth_provider1", "external_id": "external_id1"}, + {"auth_provider": "auth_provider2", "external_id": "external_id2"}, + ], + ) + self._check_fields(channel.json_body) + + # Set a new and remove an external_id + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": [ + { + "external_id": "external_id2", + "auth_provider": "auth_provider2", + }, + { + "external_id": "external_id3", + "auth_provider": "auth_provider3", + }, + ] + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["external_ids"])) + self.assertEqual( + channel.json_body["external_ids"], + [ + {"auth_provider": "auth_provider2", "external_id": "external_id2"}, + {"auth_provider": "auth_provider3", "external_id": "external_id3"}, + ], + ) + self._check_fields(channel.json_body) + + # Get user + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual( + channel.json_body["external_ids"], + [ + {"auth_provider": "auth_provider2", "external_id": "external_id2"}, + {"auth_provider": "auth_provider3", "external_id": "external_id3"}, + ], + ) + self._check_fields(channel.json_body) + + # Remove external_ids + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"external_ids": []}, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(0, len(channel.json_body["external_ids"])) + def test_deactivate_user(self): """ Test deactivating another user. -- cgit 1.5.1 From 220f901229a506a82aedc51c5923768bf935ea4f Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 19 Aug 2021 11:25:05 +0200 Subject: Remove not needed database updates in modify user admin API (#10627) --- changelog.d/10627.misc | 1 + docs/admin_api/user_admin_api.md | 8 +++- synapse/rest/admin/users.py | 55 ++++++++++++++--------- synapse/storage/databases/main/registration.py | 25 ++++++++--- tests/rest/admin/test_user.py | 62 ++++++++++++++++++++++++-- 5 files changed, 118 insertions(+), 33 deletions(-) create mode 100644 changelog.d/10627.misc (limited to 'synapse/storage/databases/main/registration.py') diff --git a/changelog.d/10627.misc b/changelog.d/10627.misc new file mode 100644 index 0000000000..e6d314976e --- /dev/null +++ b/changelog.d/10627.misc @@ -0,0 +1 @@ +Remove not needed database updates in modify user admin API. \ No newline at end of file diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 6a9335d6ec..60dc913915 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -21,11 +21,15 @@ It returns a JSON body like the following: "threepids": [ { "medium": "email", - "address": "" + "address": "", + "added_at": 1586458409743, + "validated_at": 1586458409743 }, { "medium": "email", - "address": "" + "address": "", + "added_at": 1586458409743, + "validated_at": 1586458409743 } ], "avatar_url": "", diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 3c8a0c6883..c1a1ba645e 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -228,13 +228,18 @@ class UserRestServletV2(RestServlet): if not isinstance(deactivate, bool): raise SynapseError(400, "'deactivated' parameter is not of type boolean") - # convert into List[Tuple[str, str]] + # convert List[Dict[str, str]] into Set[Tuple[str, str]] if external_ids is not None: - new_external_ids = [] - for external_id in external_ids: - new_external_ids.append( - (external_id["auth_provider"], external_id["external_id"]) - ) + new_external_ids = { + (external_id["auth_provider"], external_id["external_id"]) + for external_id in external_ids + } + + # convert List[Dict[str, str]] into Set[Tuple[str, str]] + if threepids is not None: + new_threepids = { + (threepid["medium"], threepid["address"]) for threepid in threepids + } if user: # modify user if "displayname" in body: @@ -243,29 +248,39 @@ class UserRestServletV2(RestServlet): ) if threepids is not None: - # remove old threepids from user - old_threepids = await self.store.user_get_threepids(user_id) - for threepid in old_threepids: + # get changed threepids (added and removed) + # convert List[Dict[str, Any]] into Set[Tuple[str, str]] + cur_threepids = { + (threepid["medium"], threepid["address"]) + for threepid in await self.store.user_get_threepids(user_id) + } + add_threepids = new_threepids - cur_threepids + del_threepids = cur_threepids - new_threepids + + # remove old threepids + for medium, address in del_threepids: try: await self.auth_handler.delete_threepid( - user_id, threepid["medium"], threepid["address"], None + user_id, medium, address, None ) except Exception: logger.exception("Failed to remove threepids") raise SynapseError(500, "Failed to remove threepids") - # add new threepids to user + # add new threepids current_time = self.hs.get_clock().time_msec() - for threepid in threepids: + for medium, address in add_threepids: await self.auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], current_time + user_id, medium, address, current_time ) if external_ids is not None: # get changed external_ids (added and removed) - cur_external_ids = await self.store.get_external_ids_by_user(user_id) - add_external_ids = set(new_external_ids) - set(cur_external_ids) - del_external_ids = set(cur_external_ids) - set(new_external_ids) + cur_external_ids = set( + await self.store.get_external_ids_by_user(user_id) + ) + add_external_ids = new_external_ids - cur_external_ids + del_external_ids = cur_external_ids - new_external_ids # remove old external_ids for auth_provider, external_id in del_external_ids: @@ -348,9 +363,9 @@ class UserRestServletV2(RestServlet): if threepids is not None: current_time = self.hs.get_clock().time_msec() - for threepid in threepids: + for medium, address in new_threepids: await self.auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], current_time + user_id, medium, address, current_time ) if ( self.hs.config.email_enable_notifs @@ -362,8 +377,8 @@ class UserRestServletV2(RestServlet): kind="email", app_id="m.email", app_display_name="Email Notifications", - device_display_name=threepid["address"], - pushkey=threepid["address"], + device_display_name=address, + pushkey=address, lang=None, # We don't know a user's language here data={}, ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index c67bea81c6..469dd53e0c 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -754,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) return user_id - def get_user_id_by_threepid_txn(self, txn, medium, address): + def get_user_id_by_threepid_txn( + self, txn, medium: str, address: str + ) -> Optional[str]: """Returns user id from threepid Args: txn (cursor): - medium (str): threepid medium e.g. email - address (str): threepid address e.g. me@example.com + medium: threepid medium e.g. email + address: threepid address e.g. me@example.com Returns: - str|None: user id or None if no user id/threepid mapping exists + user id, or None if no user id/threepid mapping exists """ ret = self.db_pool.simple_select_one_txn( txn, @@ -776,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return ret["user_id"] return None - async def user_add_threepid(self, user_id, medium, address, validated_at, added_at): + async def user_add_threepid( + self, + user_id: str, + medium: str, + address: str, + validated_at: int, + added_at: int, + ) -> None: await self.db_pool.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, ) - async def user_get_threepids(self, user_id): + async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]: return await self.db_pool.simple_select_list( "user_threepids", {"user_id": user_id}, @@ -791,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "user_get_threepids", ) - async def user_delete_threepid(self, user_id, medium, address) -> None: + async def user_delete_threepid( + self, user_id: str, medium: str, address: str + ) -> None: await self.db_pool.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ef77275238..ee204c404b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1431,12 +1431,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual( "external_id1", channel.json_body["external_ids"][0]["external_id"] ) self.assertEqual( "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"] ) + self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertFalse(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self._check_fields(channel.json_body) @@ -1676,18 +1678,53 @@ class UserRestTestCase(unittest.HomeserverTestCase): Test setting threepid for an other user. """ - # Delete old and add new threepid to user + # Add two threepids to user channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}, + content={ + "threepids": [ + {"medium": "email", "address": "bob1@bob.bob"}, + {"medium": "email", "address": "bob2@bob.bob"}, + ], + }, ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["threepids"])) + # result does not always have the same sort order, therefore it becomes sorted + sorted_result = sorted( + channel.json_body["threepids"], key=lambda k: k["address"] + ) + self.assertEqual("email", sorted_result[0]["medium"]) + self.assertEqual("bob1@bob.bob", sorted_result[0]["address"]) + self.assertEqual("email", sorted_result[1]["medium"]) + self.assertEqual("bob2@bob.bob", sorted_result[1]["address"]) + self._check_fields(channel.json_body) + + # Set a new and remove a threepid + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "threepids": [ + {"medium": "email", "address": "bob2@bob.bob"}, + {"medium": "email", "address": "bob3@bob.bob"}, + ], + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual("email", channel.json_body["threepids"][1]["medium"]) + self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"]) + self._check_fields(channel.json_body) # Get user channel = self.make_request( @@ -1698,8 +1735,24 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual("email", channel.json_body["threepids"][1]["medium"]) + self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"]) + self._check_fields(channel.json_body) + + # Remove threepids + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"threepids": []}, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self._check_fields(channel.json_body) def test_set_external_id(self): """ @@ -1778,6 +1831,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual( channel.json_body["external_ids"], [ -- cgit 1.5.1 From 947dbbdfd1e0029da66f956d277b7c089928e1e7 Mon Sep 17 00:00:00 2001 From: Callum Brown Date: Sat, 21 Aug 2021 22:14:43 +0100 Subject: Implement MSC3231: Token authenticated registration (#10142) Signed-off-by: Callum Brown This is part of my GSoC project implementing [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231). --- changelog.d/10142.feature | 1 + docs/SUMMARY.md | 1 + docs/sample_config.yaml | 15 + .../admin_api/registration_tokens.md | 295 +++++++++ docs/workers.md | 1 + synapse/api/constants.py | 1 + synapse/app/generic_worker.py | 6 +- synapse/config/ratelimiting.py | 11 + synapse/config/registration.py | 15 + synapse/handlers/ui_auth/__init__.py | 5 + synapse/handlers/ui_auth/checkers.py | 65 ++ synapse/res/templates/registration_token.html | 23 + synapse/rest/admin/__init__.py | 8 + synapse/rest/admin/registration_tokens.py | 321 ++++++++++ synapse/rest/client/auth.py | 24 + synapse/rest/client/register.py | 72 +++ synapse/storage/databases/main/registration.py | 316 +++++++++ synapse/storage/databases/main/ui_auth.py | 43 ++ .../main/delta/63/01create_registration_tokens.sql | 23 + tests/rest/admin/test_registration_tokens.py | 710 +++++++++++++++++++++ tests/rest/client/test_register.py | 434 +++++++++++++ 21 files changed, 2389 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10142.feature create mode 100644 docs/usage/administration/admin_api/registration_tokens.md create mode 100644 synapse/res/templates/registration_token.html create mode 100644 synapse/rest/admin/registration_tokens.py create mode 100644 synapse/storage/schema/main/delta/63/01create_registration_tokens.sql create mode 100644 tests/rest/admin/test_registration_tokens.py (limited to 'synapse/storage/databases/main/registration.py') diff --git a/changelog.d/10142.feature b/changelog.d/10142.feature new file mode 100644 index 0000000000..5353f6269d --- /dev/null +++ b/changelog.d/10142.feature @@ -0,0 +1 @@ +Add support for [MSC3231 - Token authenticated registration](https://github.com/matrix-org/matrix-doc/pull/3231). Users can be required to submit a token during registration to authenticate themselves. Contributed by Callum Brown. diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 634bc833ab..4fcd2b7852 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -53,6 +53,7 @@ - [Media](admin_api/media_admin_api.md) - [Purge History](admin_api/purge_history_api.md) - [Register Users](admin_api/register_api.md) + - [Registration Tokens](usage/administration/admin_api/registration_tokens.md) - [Manipulate Room Membership](admin_api/room_membership.md) - [Rooms](admin_api/rooms.md) - [Server Notices](admin_api/server_notices.md) diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 2b0c453242..935841dbfa 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -793,6 +793,8 @@ log_config: "CONFDIR/SERVERNAME.log.config" # is using # - one for registration that ratelimits registration requests based on the # client's IP address. +# - one for checking the validity of registration tokens that ratelimits +# requests based on the client's IP address. # - one for login that ratelimits login requests based on the client's IP # address. # - one for login that ratelimits login requests based on the account the @@ -821,6 +823,10 @@ log_config: "CONFDIR/SERVERNAME.log.config" # per_second: 0.17 # burst_count: 3 # +#rc_registration_token_validity: +# per_second: 0.1 +# burst_count: 5 +# #rc_login: # address: # per_second: 0.17 @@ -1169,6 +1175,15 @@ url_preview_accept_language: # #enable_3pid_lookup: true +# Require users to submit a token during registration. +# Tokens can be managed using the admin API: +# https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html +# Note that `enable_registration` must be set to `true`. +# Disabling this option will not delete any tokens previously generated. +# Defaults to false. Uncomment the following to require tokens: +# +#registration_requires_token: true + # If set, allows registration of standard or admin accounts by anyone who # has the shared secret, even if registration is otherwise disabled. # diff --git a/docs/usage/administration/admin_api/registration_tokens.md b/docs/usage/administration/admin_api/registration_tokens.md new file mode 100644 index 0000000000..828c0277d6 --- /dev/null +++ b/docs/usage/administration/admin_api/registration_tokens.md @@ -0,0 +1,295 @@ +# Registration Tokens + +This API allows you to manage tokens which can be used to authenticate +registration requests, as proposed in [MSC3231](https://github.com/govynnus/matrix-doc/blob/token-registration/proposals/3231-token-authenticated-registration.md). +To use it, you will need to enable the `registration_requires_token` config +option, and authenticate by providing an `access_token` for a server admin: +see [Admin API](../../usage/administration/admin_api). +Note that this API is still experimental; not all clients may support it yet. + + +## Registration token objects + +Most endpoints make use of JSON objects that contain details about tokens. +These objects have the following fields: +- `token`: The token which can be used to authenticate registration. +- `uses_allowed`: The number of times the token can be used to complete a + registration before it becomes invalid. +- `pending`: The number of pending uses the token has. When someone uses + the token to authenticate themselves, the pending counter is incremented + so that the token is not used more than the permitted number of times. + When the person completes registration the pending counter is decremented, + and the completed counter is incremented. +- `completed`: The number of times the token has been used to successfully + complete a registration. +- `expiry_time`: The latest time the token is valid. Given as the number of + milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch). + To convert this into a human-readable form you can remove the milliseconds + and use the `date` command. For example, `date -d '@1625394937'`. + + +## List all tokens + +Lists all tokens and details about them. If the request is successful, the top +level JSON object will have a `registration_tokens` key which is an array of +registration token objects. + +``` +GET /_synapse/admin/v1/registration_tokens +``` + +Optional query parameters: +- `valid`: `true` or `false`. If `true`, only valid tokens are returned. + If `false`, only tokens that have expired or have had all uses exhausted are + returned. If omitted, all tokens are returned regardless of validity. + +Example: + +``` +GET /_synapse/admin/v1/registration_tokens +``` +``` +200 OK + +{ + "registration_tokens": [ + { + "token": "abcd", + "uses_allowed": 3, + "pending": 0, + "completed": 1, + "expiry_time": null + }, + { + "token": "pqrs", + "uses_allowed": 2, + "pending": 1, + "completed": 1, + "expiry_time": null + }, + { + "token": "wxyz", + "uses_allowed": null, + "pending": 0, + "completed": 9, + "expiry_time": 1625394937000 // 2021-07-04 10:35:37 UTC + } + ] +} +``` + +Example using the `valid` query parameter: + +``` +GET /_synapse/admin/v1/registration_tokens?valid=false +``` +``` +200 OK + +{ + "registration_tokens": [ + { + "token": "pqrs", + "uses_allowed": 2, + "pending": 1, + "completed": 1, + "expiry_time": null + }, + { + "token": "wxyz", + "uses_allowed": null, + "pending": 0, + "completed": 9, + "expiry_time": 1625394937000 // 2021-07-04 10:35:37 UTC + } + ] +} +``` + + +## Get one token + +Get details about a single token. If the request is successful, the response +body will be a registration token object. + +``` +GET /_synapse/admin/v1/registration_tokens/ +``` + +Path parameters: +- `token`: The registration token to return details of. + +Example: + +``` +GET /_synapse/admin/v1/registration_tokens/abcd +``` +``` +200 OK + +{ + "token": "abcd", + "uses_allowed": 3, + "pending": 0, + "completed": 1, + "expiry_time": null +} +``` + + +## Create token + +Create a new registration token. If the request is successful, the newly created +token will be returned as a registration token object in the response body. + +``` +POST /_synapse/admin/v1/registration_tokens/new +``` + +The request body must be a JSON object and can contain the following fields: +- `token`: The registration token. A string of no more than 64 characters that + consists only of characters matched by the regex `[A-Za-z0-9-_]`. + Default: randomly generated. +- `uses_allowed`: The integer number of times the token can be used to complete + a registration before it becomes invalid. + Default: `null` (unlimited uses). +- `expiry_time`: The latest time the token is valid. Given as the number of + milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch). + You could use, for example, `date '+%s000' -d 'tomorrow'`. + Default: `null` (token does not expire). +- `length`: The length of the token randomly generated if `token` is not + specified. Must be between 1 and 64 inclusive. Default: `16`. + +If a field is omitted the default is used. + +Example using defaults: + +``` +POST /_synapse/admin/v1/registration_tokens/new + +{} +``` +``` +200 OK + +{ + "token": "0M-9jbkf2t_Tgiw1", + "uses_allowed": null, + "pending": 0, + "completed": 0, + "expiry_time": null +} +``` + +Example specifying some fields: + +``` +POST /_synapse/admin/v1/registration_tokens/new + +{ + "token": "defg", + "uses_allowed": 1 +} +``` +``` +200 OK + +{ + "token": "defg", + "uses_allowed": 1, + "pending": 0, + "completed": 0, + "expiry_time": null +} +``` + + +## Update token + +Update the number of allowed uses or expiry time of a token. If the request is +successful, the updated token will be returned as a registration token object +in the response body. + +``` +PUT /_synapse/admin/v1/registration_tokens/ +``` + +Path parameters: +- `token`: The registration token to update. + +The request body must be a JSON object and can contain the following fields: +- `uses_allowed`: The integer number of times the token can be used to complete + a registration before it becomes invalid. By setting `uses_allowed` to `0` + the token can be easily made invalid without deleting it. + If `null` the token will have an unlimited number of uses. +- `expiry_time`: The latest time the token is valid. Given as the number of + milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch). + If `null` the token will not expire. + +If a field is omitted its value is not modified. + +Example: + +``` +PUT /_synapse/admin/v1/registration_tokens/defg + +{ + "expiry_time": 4781243146000 // 2121-07-06 11:05:46 UTC +} +``` +``` +200 OK + +{ + "token": "defg", + "uses_allowed": 1, + "pending": 0, + "completed": 0, + "expiry_time": 4781243146000 +} +``` + + +## Delete token + +Delete a registration token. If the request is successful, the response body +will be an empty JSON object. + +``` +DELETE /_synapse/admin/v1/registration_tokens/ +``` + +Path parameters: +- `token`: The registration token to delete. + +Example: + +``` +DELETE /_synapse/admin/v1/registration_tokens/wxyz +``` +``` +200 OK + +{} +``` + + +## Errors + +If a request fails a "standard error response" will be returned as defined in +the [Matrix Client-Server API specification](https://matrix.org/docs/spec/client_server/r0.6.1#api-standards). + +For example, if the token specified in a path parameter does not exist a +`404 Not Found` error will be returned. + +``` +GET /_synapse/admin/v1/registration_tokens/1234 +``` +``` +404 Not Found + +{ + "errcode": "M_NOT_FOUND", + "error": "No such registration token: 1234" +} +``` diff --git a/docs/workers.md b/docs/workers.md index 2e63f03452..3121241894 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -236,6 +236,7 @@ expressions: # Registration/login requests ^/_matrix/client/(api/v1|r0|unstable)/login$ ^/_matrix/client/(r0|unstable)/register$ + ^/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity$ # Event sending requests ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact diff --git a/synapse/api/constants.py b/synapse/api/constants.py index e0e24fddac..829061c870 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -79,6 +79,7 @@ class LoginType: TERMS = "m.login.terms" SSO = "m.login.sso" DUMMY = "m.login.dummy" + REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token" # This is used in the `type` parameter for /register when called by diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 845e6a8220..fd2626dbe1 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -95,7 +95,10 @@ from synapse.rest.client.profile import ( ProfileRestServlet, ) from synapse.rest.client.push_rule import PushRuleRestServlet -from synapse.rest.client.register import RegisterRestServlet +from synapse.rest.client.register import ( + RegisterRestServlet, + RegistrationTokenValidityRestServlet, +) from synapse.rest.client.sendtodevice import SendToDeviceRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.rest.client.voip import VoipRestServlet @@ -279,6 +282,7 @@ class GenericWorkerServer(HomeServer): resource = JsonResource(self, canonical_json=False) RegisterRestServlet(self).register(resource) + RegistrationTokenValidityRestServlet(self).register(resource) login.register_servlets(self, resource) ThreepidRestServlet(self).register(resource) DevicesRestServlet(self).register(resource) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 7a8d5851c4..f856327bd8 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -79,6 +79,11 @@ class RatelimitConfig(Config): self.rc_registration = RateLimitConfig(config.get("rc_registration", {})) + self.rc_registration_token_validity = RateLimitConfig( + config.get("rc_registration_token_validity", {}), + defaults={"per_second": 0.1, "burst_count": 5}, + ) + rc_login_config = config.get("rc_login", {}) self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {})) self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {})) @@ -143,6 +148,8 @@ class RatelimitConfig(Config): # is using # - one for registration that ratelimits registration requests based on the # client's IP address. + # - one for checking the validity of registration tokens that ratelimits + # requests based on the client's IP address. # - one for login that ratelimits login requests based on the client's IP # address. # - one for login that ratelimits login requests based on the account the @@ -171,6 +178,10 @@ class RatelimitConfig(Config): # per_second: 0.17 # burst_count: 3 # + #rc_registration_token_validity: + # per_second: 0.1 + # burst_count: 5 + # #rc_login: # address: # per_second: 0.17 diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 0ad919b139..7cffdacfa5 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -33,6 +33,9 @@ class RegistrationConfig(Config): self.registrations_require_3pid = config.get("registrations_require_3pid", []) self.allowed_local_3pids = config.get("allowed_local_3pids", []) self.enable_3pid_lookup = config.get("enable_3pid_lookup", True) + self.registration_requires_token = config.get( + "registration_requires_token", False + ) self.registration_shared_secret = config.get("registration_shared_secret") self.bcrypt_rounds = config.get("bcrypt_rounds", 12) @@ -140,6 +143,9 @@ class RegistrationConfig(Config): "mechanism by removing the `access_token_lifetime` option." ) + # The fallback template used for authenticating using a registration token + self.registration_token_template = self.read_template("registration_token.html") + # The success template used during fallback auth. self.fallback_success_template = self.read_template("auth_success.html") @@ -199,6 +205,15 @@ class RegistrationConfig(Config): # #enable_3pid_lookup: true + # Require users to submit a token during registration. + # Tokens can be managed using the admin API: + # https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html + # Note that `enable_registration` must be set to `true`. + # Disabling this option will not delete any tokens previously generated. + # Defaults to false. Uncomment the following to require tokens: + # + #registration_requires_token: true + # If set, allows registration of standard or admin accounts by anyone who # has the shared secret, even if registration is otherwise disabled. # diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py index 4c3b669fae..13b0c61d2e 100644 --- a/synapse/handlers/ui_auth/__init__.py +++ b/synapse/handlers/ui_auth/__init__.py @@ -34,3 +34,8 @@ class UIAuthSessionDataConstants: # used by validate_user_via_ui_auth to store the mxid of the user we are validating # for. REQUEST_USER_ID = "request_user_id" + + # used during registration to store the registration token used (if required) so that: + # - we can prevent a token being used twice by one session + # - we can 'use up' the token after registration has successfully completed + REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token" diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 270541cc76..d3828dec6b 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -241,11 +241,76 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): return await self._check_threepid("msisdn", authdict) +class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): + AUTH_TYPE = LoginType.REGISTRATION_TOKEN + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + self.hs = hs + self._enabled = bool(hs.config.registration_requires_token) + self.store = hs.get_datastore() + + def is_enabled(self) -> bool: + return self._enabled + + async def check_auth(self, authdict: dict, clientip: str) -> Any: + if "token" not in authdict: + raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM) + if not isinstance(authdict["token"], str): + raise LoginError( + 400, "Registration token must be a string", Codes.INVALID_PARAM + ) + if "session" not in authdict: + raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM) + + # Get these here to avoid cyclic dependencies + from synapse.handlers.ui_auth import UIAuthSessionDataConstants + + auth_handler = self.hs.get_auth_handler() + + session = authdict["session"] + token = authdict["token"] + + # If the LoginType.REGISTRATION_TOKEN stage has already been completed, + # return early to avoid incrementing `pending` again. + stored_token = await auth_handler.get_session_data( + session, UIAuthSessionDataConstants.REGISTRATION_TOKEN + ) + if stored_token: + if token != stored_token: + raise LoginError( + 400, "Registration token has changed", Codes.INVALID_PARAM + ) + else: + return token + + if await self.store.registration_token_is_valid(token): + # Increment pending counter, so that if token has limited uses it + # can't be used up by someone else in the meantime. + await self.store.set_registration_token_pending(token) + # Store the token in the UIA session, so that once registration + # is complete `completed` can be incremented. + await auth_handler.set_session_data( + session, + UIAuthSessionDataConstants.REGISTRATION_TOKEN, + token, + ) + # The token will be stored as the result of the authentication stage + # in ui_auth_sessions_credentials. This allows the pending counter + # for tokens to be decremented when expired sessions are deleted. + return token + else: + raise LoginError( + 401, "Invalid registration token", errcode=Codes.UNAUTHORIZED + ) + + INTERACTIVE_AUTH_CHECKERS = [ DummyAuthChecker, TermsAuthChecker, RecaptchaAuthChecker, EmailIdentityAuthChecker, MsisdnAuthChecker, + RegistrationTokenAuthChecker, ] """A list of UserInteractiveAuthChecker classes""" diff --git a/synapse/res/templates/registration_token.html b/synapse/res/templates/registration_token.html new file mode 100644 index 0000000000..4577ce1702 --- /dev/null +++ b/synapse/res/templates/registration_token.html @@ -0,0 +1,23 @@ + + +Authentication + + + + +
+
+ {% if error is defined %} +

Error: {{ error }}

+ {% endif %} +

+ Please enter a registration token. +

+ + + +
+
+ + diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 7f3051aef1..6e1c8736e1 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -36,6 +36,11 @@ from synapse.rest.admin.event_reports import ( ) from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo +from synapse.rest.admin.registration_tokens import ( + ListRegistrationTokensRestServlet, + NewRegistrationTokenRestServlet, + RegistrationTokenRestServlet, +) from synapse.rest.admin.rooms import ( DeleteRoomRestServlet, ForwardExtremitiesRestServlet, @@ -238,6 +243,9 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomEventContextServlet(hs).register(http_server) RateLimitRestServlet(hs).register(http_server) UsernameAvailableRestServlet(hs).register(http_server) + ListRegistrationTokensRestServlet(hs).register(http_server) + NewRegistrationTokenRestServlet(hs).register(http_server) + RegistrationTokenRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py new file mode 100644 index 0000000000..5a1c929d85 --- /dev/null +++ b/synapse/rest/admin/registration_tokens.py @@ -0,0 +1,321 @@ +# Copyright 2021 Callum Brown +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import string +from typing import TYPE_CHECKING, Tuple + +from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_boolean, + parse_json_object_from_request, +) +from synapse.http.site import SynapseRequest +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ListRegistrationTokensRestServlet(RestServlet): + """List registration tokens. + + To list all tokens: + + GET /_synapse/admin/v1/registration_tokens + + 200 OK + + { + "registration_tokens": [ + { + "token": "abcd", + "uses_allowed": 3, + "pending": 0, + "completed": 1, + "expiry_time": null + }, + { + "token": "wxyz", + "uses_allowed": null, + "pending": 0, + "completed": 9, + "expiry_time": 1625394937000 + } + ] + } + + The optional query parameter `valid` can be used to filter the response. + If it is `true`, only valid tokens are returned. If it is `false`, only + tokens that have expired or have had all uses exhausted are returned. + If it is omitted, all tokens are returned regardless of validity. + """ + + PATTERNS = admin_patterns("/registration_tokens$") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + valid = parse_boolean(request, "valid") + token_list = await self.store.get_registration_tokens(valid) + return 200, {"registration_tokens": token_list} + + +class NewRegistrationTokenRestServlet(RestServlet): + """Create a new registration token. + + For example, to create a token specifying some fields: + + POST /_synapse/admin/v1/registration_tokens/new + + { + "token": "defg", + "uses_allowed": 1 + } + + 200 OK + + { + "token": "defg", + "uses_allowed": 1, + "pending": 0, + "completed": 0, + "expiry_time": null + } + + Defaults are used for any fields not specified. + """ + + PATTERNS = admin_patterns("/registration_tokens/new$") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + # A string of all the characters allowed to be in a registration_token + self.allowed_chars = string.ascii_letters + string.digits + "-_" + self.allowed_chars_set = set(self.allowed_chars) + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + body = parse_json_object_from_request(request) + + if "token" in body: + token = body["token"] + if not isinstance(token, str): + raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM) + if not (0 < len(token) <= 64): + raise SynapseError( + 400, + "token must not be empty and must not be longer than 64 characters", + Codes.INVALID_PARAM, + ) + if not set(token).issubset(self.allowed_chars_set): + raise SynapseError( + 400, + "token must consist only of characters matched by the regex [A-Za-z0-9-_]", + Codes.INVALID_PARAM, + ) + + else: + # Get length of token to generate (default is 16) + length = body.get("length", 16) + if not isinstance(length, int): + raise SynapseError( + 400, "length must be an integer", Codes.INVALID_PARAM + ) + if not (0 < length <= 64): + raise SynapseError( + 400, + "length must be greater than zero and not greater than 64", + Codes.INVALID_PARAM, + ) + + # Generate token + token = await self.store.generate_registration_token( + length, self.allowed_chars + ) + + uses_allowed = body.get("uses_allowed", None) + if not ( + uses_allowed is None + or (isinstance(uses_allowed, int) and uses_allowed >= 0) + ): + raise SynapseError( + 400, + "uses_allowed must be a non-negative integer or null", + Codes.INVALID_PARAM, + ) + + expiry_time = body.get("expiry_time", None) + if not isinstance(expiry_time, (int, type(None))): + raise SynapseError( + 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM + ) + if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): + raise SynapseError( + 400, "expiry_time must not be in the past", Codes.INVALID_PARAM + ) + + created = await self.store.create_registration_token( + token, uses_allowed, expiry_time + ) + if not created: + raise SynapseError( + 400, f"Token already exists: {token}", Codes.INVALID_PARAM + ) + + resp = { + "token": token, + "uses_allowed": uses_allowed, + "pending": 0, + "completed": 0, + "expiry_time": expiry_time, + } + return 200, resp + + +class RegistrationTokenRestServlet(RestServlet): + """Retrieve, update, or delete the given token. + + For example, + + to retrieve a token: + + GET /_synapse/admin/v1/registration_tokens/abcd + + 200 OK + + { + "token": "abcd", + "uses_allowed": 3, + "pending": 0, + "completed": 1, + "expiry_time": null + } + + + to update a token: + + PUT /_synapse/admin/v1/registration_tokens/defg + + { + "uses_allowed": 5, + "expiry_time": 4781243146000 + } + + 200 OK + + { + "token": "defg", + "uses_allowed": 5, + "pending": 0, + "completed": 0, + "expiry_time": 4781243146000 + } + + + to delete a token: + + DELETE /_synapse/admin/v1/registration_tokens/wxyz + + 200 OK + + {} + """ + + PATTERNS = admin_patterns("/registration_tokens/(?P[^/]*)$") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.clock = hs.get_clock() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]: + """Retrieve a registration token.""" + await assert_requester_is_admin(self.auth, request) + token_info = await self.store.get_one_registration_token(token) + + # If no result return a 404 + if token_info is None: + raise NotFoundError(f"No such registration token: {token}") + + return 200, token_info + + async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]: + """Update a registration token.""" + await assert_requester_is_admin(self.auth, request) + body = parse_json_object_from_request(request) + new_attributes = {} + + # Only add uses_allowed to new_attributes if it is present and valid + if "uses_allowed" in body: + uses_allowed = body["uses_allowed"] + if not ( + uses_allowed is None + or (isinstance(uses_allowed, int) and uses_allowed >= 0) + ): + raise SynapseError( + 400, + "uses_allowed must be a non-negative integer or null", + Codes.INVALID_PARAM, + ) + new_attributes["uses_allowed"] = uses_allowed + + if "expiry_time" in body: + expiry_time = body["expiry_time"] + if not isinstance(expiry_time, (int, type(None))): + raise SynapseError( + 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM + ) + if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): + raise SynapseError( + 400, "expiry_time must not be in the past", Codes.INVALID_PARAM + ) + new_attributes["expiry_time"] = expiry_time + + if len(new_attributes) == 0: + # Nothing to update, get token info to return + token_info = await self.store.get_one_registration_token(token) + else: + token_info = await self.store.update_registration_token( + token, new_attributes + ) + + # If no result return a 404 + if token_info is None: + raise NotFoundError(f"No such registration token: {token}") + + return 200, token_info + + async def on_DELETE( + self, request: SynapseRequest, token: str + ) -> Tuple[int, JsonDict]: + """Delete a registration token.""" + await assert_requester_is_admin(self.auth, request) + + if await self.store.delete_registration_token(token): + return 200, {} + + raise NotFoundError(f"No such registration token: {token}") diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 73284e48ec..91800c0278 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -46,6 +46,7 @@ class AuthRestServlet(RestServlet): self.registration_handler = hs.get_registration_handler() self.recaptcha_template = hs.config.recaptcha_template self.terms_template = hs.config.terms_template + self.registration_token_template = hs.config.registration_token_template self.success_template = hs.config.fallback_success_template async def on_GET(self, request, stagetype): @@ -74,6 +75,12 @@ class AuthRestServlet(RestServlet): # re-authenticate with their SSO provider. html = await self.auth_handler.start_sso_ui_auth(request, session) + elif stagetype == LoginType.REGISTRATION_TOKEN: + html = self.registration_token_template.render( + session=session, + myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web", + ) + else: raise SynapseError(404, "Unknown auth stage type") @@ -140,6 +147,23 @@ class AuthRestServlet(RestServlet): # The SSO fallback workflow should not post here, raise SynapseError(404, "Fallback SSO auth does not support POST requests.") + elif stagetype == LoginType.REGISTRATION_TOKEN: + token = parse_string(request, "token", required=True) + authdict = {"session": session, "token": token} + + try: + await self.auth_handler.add_oob_auth( + LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP() + ) + except LoginError as e: + html = self.registration_token_template.render( + session=session, + myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web", + error=e.msg, + ) + else: + html = self.success_template.render() + else: raise SynapseError(404, "Unknown auth stage type") diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 58b8e8f261..2781a0ea96 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -28,6 +28,7 @@ from synapse.api.errors import ( ThreepidValidationError, UnrecognizedRequestError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.config import ConfigError from synapse.config.captcha import CaptchaConfig from synapse.config.consent import ConsentConfig @@ -379,6 +380,55 @@ class UsernameAvailabilityRestServlet(RestServlet): return 200, {"available": True} +class RegistrationTokenValidityRestServlet(RestServlet): + """Check the validity of a registration token. + + Example: + + GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd + + 200 OK + + { + "valid": true + } + """ + + PATTERNS = client_patterns( + f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity", + releases=(), + unstable=True, + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.store = hs.get_datastore() + self.ratelimiter = Ratelimiter( + store=self.store, + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second, + burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count, + ) + + async def on_GET(self, request): + await self.ratelimiter.ratelimit(None, (request.getClientIP(),)) + + if not self.hs.config.enable_registration: + raise SynapseError( + 403, "Registration has been disabled", errcode=Codes.FORBIDDEN + ) + + token = parse_string(request, "token", required=True) + valid = await self.store.registration_token_is_valid(token) + + return 200, {"valid": valid} + + class RegisterRestServlet(RestServlet): PATTERNS = client_patterns("/register$") @@ -686,6 +736,22 @@ class RegisterRestServlet(RestServlet): ) if registered: + # Check if a token was used to authenticate registration + registration_token = await self.auth_handler.get_session_data( + session_id, + UIAuthSessionDataConstants.REGISTRATION_TOKEN, + ) + if registration_token: + # Increment the `completed` counter for the token + await self.store.use_registration_token(registration_token) + # Indicate that the token has been successfully used so that + # pending is not decremented again when expiring old UIA sessions. + await self.store.mark_ui_auth_stage_complete( + session_id, + LoginType.REGISTRATION_TOKEN, + True, + ) + await self.registration_handler.post_registration_actions( user_id=registered_user_id, auth_result=auth_result, @@ -868,6 +934,11 @@ def _calculate_registration_flows( for flow in flows: flow.insert(0, LoginType.RECAPTCHA) + # Prepend registration token to all flows if we're requiring a token + if config.registration_requires_token: + for flow in flows: + flow.insert(0, LoginType.REGISTRATION_TOKEN) + return flows @@ -876,4 +947,5 @@ def register_servlets(hs, http_server): MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) UsernameAvailabilityRestServlet(hs).register(http_server) RegistrationSubmitTokenServlet(hs).register(http_server) + RegistrationTokenValidityRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 469dd53e0c..a6517962f6 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1168,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="update_access_token_last_validated", ) + async def registration_token_is_valid(self, token: str) -> bool: + """Checks if a token can be used to authenticate a registration. + + Args: + token: The registration token to be checked + Returns: + True if the token is valid, False otherwise. + """ + res = await self.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["uses_allowed", "pending", "completed", "expiry_time"], + allow_none=True, + ) + + # Check if the token exists + if res is None: + return False + + # Check if the token has expired + now = self._clock.time_msec() + if res["expiry_time"] and res["expiry_time"] < now: + return False + + # Check if the token has been used up + if ( + res["uses_allowed"] + and res["pending"] + res["completed"] >= res["uses_allowed"] + ): + return False + + # Otherwise, the token is valid + return True + + async def set_registration_token_pending(self, token: str) -> None: + """Increment the pending registrations counter for a token. + + Args: + token: The registration token pending use + """ + + def _set_registration_token_pending_txn(txn): + pending = self.db_pool.simple_select_one_onecol_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcol="pending", + ) + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues={"pending": pending + 1}, + ) + + return await self.db_pool.runInteraction( + "set_registration_token_pending", _set_registration_token_pending_txn + ) + + async def use_registration_token(self, token: str) -> None: + """Complete a use of the given registration token. + + The `pending` counter will be decremented, and the `completed` + counter will be incremented. + + Args: + token: The registration token to be 'used' + """ + + def _use_registration_token_txn(txn): + # Normally, res is Optional[Dict[str, Any]]. + # Override type because the return type is only optional if + # allow_none is True, and we don't want mypy throwing errors + # about None not being indexable. + res: Dict[str, Any] = self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ) # type: ignore + + # Decrement pending and increment completed + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues={ + "completed": res["completed"] + 1, + "pending": res["pending"] - 1, + }, + ) + + return await self.db_pool.runInteraction( + "use_registration_token", _use_registration_token_txn + ) + + async def get_registration_tokens( + self, valid: Optional[bool] = None + ) -> List[Dict[str, Any]]: + """List all registration tokens. Used by the admin API. + + Args: + valid: If True, only valid tokens are returned. + If False, only invalid tokens are returned. + Default is None: return all tokens regardless of validity. + + Returns: + A list of dicts, each containing details of a token. + """ + + def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]): + if valid is None: + # Return all tokens regardless of validity + txn.execute("SELECT * FROM registration_tokens") + + elif valid: + # Select valid tokens only + sql = ( + "SELECT * FROM registration_tokens WHERE " + "(uses_allowed > pending + completed OR uses_allowed IS NULL) " + "AND (expiry_time > ? OR expiry_time IS NULL)" + ) + txn.execute(sql, [now]) + + else: + # Select invalid tokens only + sql = ( + "SELECT * FROM registration_tokens WHERE " + "uses_allowed <= pending + completed OR expiry_time <= ?" + ) + txn.execute(sql, [now]) + + return self.db_pool.cursor_to_dict(txn) + + return await self.db_pool.runInteraction( + "select_registration_tokens", + select_registration_tokens_txn, + self._clock.time_msec(), + valid, + ) + + async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]: + """Get info about the given registration token. Used by the admin API. + + Args: + token: The token to retrieve information about. + + Returns: + A dict, or None if token doesn't exist. + """ + return await self.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"], + allow_none=True, + desc="get_one_registration_token", + ) + + async def generate_registration_token( + self, length: int, chars: str + ) -> Optional[str]: + """Generate a random registration token. Used by the admin API. + + Args: + length: The length of the token to generate. + chars: A string of the characters allowed in the generated token. + + Returns: + The generated token. + + Raises: + SynapseError if a unique registration token could still not be + generated after a few tries. + """ + # Make a few attempts at generating a unique token of the required + # length before failing. + for _i in range(3): + # Generate token + token = "".join(random.choices(chars, k=length)) + + # Check if the token already exists + existing_token = await self.db_pool.simple_select_one_onecol( + "registration_tokens", + keyvalues={"token": token}, + retcol="token", + allow_none=True, + desc="check_if_registration_token_exists", + ) + + if existing_token is None: + # The generated token doesn't exist yet, return it + return token + + raise SynapseError( + 500, + "Unable to generate a unique registration token. Try again with a greater length", + Codes.UNKNOWN, + ) + + async def create_registration_token( + self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int] + ) -> bool: + """Create a new registration token. Used by the admin API. + + Args: + token: The token to create. + uses_allowed: The number of times the token can be used to complete + a registration before it becomes invalid. A value of None indicates + unlimited uses. + expiry_time: The latest time the token is valid. Given as the + number of milliseconds since 1970-01-01 00:00:00 UTC. A value of + None indicates that the token does not expire. + + Returns: + Whether the row was inserted or not. + """ + + def _create_registration_token_txn(txn): + row = self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=["token"], + allow_none=True, + ) + + if row is not None: + # Token already exists + return False + + self.db_pool.simple_insert_txn( + txn, + "registration_tokens", + values={ + "token": token, + "uses_allowed": uses_allowed, + "pending": 0, + "completed": 0, + "expiry_time": expiry_time, + }, + ) + + return True + + return await self.db_pool.runInteraction( + "create_registration_token", _create_registration_token_txn + ) + + async def update_registration_token( + self, token: str, updatevalues: Dict[str, Optional[int]] + ) -> Optional[Dict[str, Any]]: + """Update a registration token. Used by the admin API. + + Args: + token: The token to update. + updatevalues: A dict with the fields to update. E.g.: + `{"uses_allowed": 3}` to update just uses_allowed, or + `{"uses_allowed": 3, "expiry_time": None}` to update both. + This is passed straight to simple_update_one. + + Returns: + A dict with all info about the token, or None if token doesn't exist. + """ + + def _update_registration_token_txn(txn): + try: + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues=updatevalues, + ) + except StoreError: + # Update failed because token does not exist + return None + + # Get all info about the token so it can be sent in the response + return self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=[ + "token", + "uses_allowed", + "pending", + "completed", + "expiry_time", + ], + allow_none=True, + ) + + return await self.db_pool.runInteraction( + "update_registration_token", _update_registration_token_txn + ) + + async def delete_registration_token(self, token: str) -> bool: + """Delete a registration token. Used by the admin API. + + Args: + token: The token to delete. + + Returns: + Whether the token was successfully deleted or not. + """ + try: + await self.db_pool.simple_delete_one( + "registration_tokens", + keyvalues={"token": token}, + desc="delete_registration_token", + ) + except StoreError: + # Deletion failed because token does not exist + return False + + return True + @cached() async def mark_access_token_as_used(self, token_id: int) -> None: """ diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 38bfdf5dad..4d6bbc94c7 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import attr +from synapse.api.constants import LoginType from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import LoggingTransaction @@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore): keyvalues={}, ) + # If a registration token was used, decrement the pending counter + # before deleting the session. + rows = self.db_pool.simple_select_many_txn( + txn, + table="ui_auth_sessions_credentials", + column="session_id", + iterable=session_ids, + keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN}, + retcols=["result"], + ) + + # Get the tokens used and how much pending needs to be decremented by. + token_counts: Dict[str, int] = {} + for r in rows: + # If registration was successfully completed, the result of the + # registration token stage for that session will be True. + # If a token was used to authenticate, but registration was + # never completed, the result will be the token used. + token = db_to_json(r["result"]) + if isinstance(token, str): + token_counts[token] = token_counts.get(token, 0) + 1 + + # Update the `pending` counters. + if len(token_counts) > 0: + token_rows = self.db_pool.simple_select_many_txn( + txn, + table="registration_tokens", + column="token", + iterable=list(token_counts.keys()), + keyvalues={}, + retcols=["token", "pending"], + ) + for token_row in token_rows: + token = token_row["token"] + new_pending = token_row["pending"] - token_counts[token] + self.db_pool.simple_update_one_txn( + txn, + table="registration_tokens", + keyvalues={"token": token}, + updatevalues={"pending": new_pending}, + ) + # Delete the corresponding completed credentials. self.db_pool.simple_delete_many_txn( txn, diff --git a/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql new file mode 100644 index 0000000000..ee6cf958f4 --- /dev/null +++ b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql @@ -0,0 +1,23 @@ +/* Copyright 2021 Callum Brown + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS registration_tokens( + token TEXT NOT NULL, -- The token that can be used for authentication. + uses_allowed INT, -- The total number of times this token can be used. NULL if no limit. + pending INT NOT NULL, -- The number of in progress registrations using this token. + completed INT NOT NULL, -- The number of times this token has been used to complete a registration. + expiry_time BIGINT, -- The latest time this token will be valid (epoch time in milliseconds). NULL if token doesn't expire. + UNIQUE (token) +); diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py new file mode 100644 index 0000000000..4927321e5a --- /dev/null +++ b/tests/rest/admin/test_registration_tokens.py @@ -0,0 +1,710 @@ +# Copyright 2021 Callum Brown +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import string + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client import login + +from tests import unittest + + +class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.url = "/_synapse/admin/v1/registration_tokens" + + def _new_token(self, **kwargs): + """Helper function to create a token.""" + token = kwargs.get( + "token", + "".join(random.choices(string.ascii_letters, k=8)), + ) + self.get_success( + self.store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": kwargs.get("uses_allowed", None), + "pending": kwargs.get("pending", 0), + "completed": kwargs.get("completed", 0), + "expiry_time": kwargs.get("expiry_time", None), + }, + ) + ) + return token + + # CREATION + + def test_create_no_auth(self): + """Try to create a token without authentication.""" + channel = self.make_request("POST", self.url + "/new", {}) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_create_requester_not_admin(self): + """Try to create a token while not an admin.""" + channel = self.make_request( + "POST", + self.url + "/new", + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_create_using_defaults(self): + """Create a token using all the defaults.""" + channel = self.make_request( + "POST", + self.url + "/new", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["token"]), 16) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + def test_create_specifying_fields(self): + """Create a token specifying the value of all fields.""" + data = { + "token": "abcd", + "uses_allowed": 1, + "expiry_time": self.clock.time_msec() + 1000000, + } + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["token"], "abcd") + self.assertEqual(channel.json_body["uses_allowed"], 1) + self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + def test_create_with_null_value(self): + """Create a token specifying unlimited uses and no expiry.""" + data = { + "uses_allowed": None, + "expiry_time": None, + } + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["token"]), 16) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + def test_create_token_too_long(self): + """Check token longer than 64 chars is invalid.""" + data = {"token": "a" * 65} + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_token_invalid_chars(self): + """Check you can't create token with invalid characters.""" + data = { + "token": "abc/def", + } + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_token_already_exists(self): + """Check you can't create token that already exists.""" + data = { + "token": "abcd", + } + + channel1 = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"]) + + channel2 = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"]) + self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_unable_to_generate_token(self): + """Check right error is raised when server can't generate unique token.""" + # Create all possible single character tokens + tokens = [] + for c in string.ascii_letters + string.digits + "-_": + tokens.append( + { + "token": c, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + } + ) + self.get_success( + self.store.db_pool.simple_insert_many( + "registration_tokens", + tokens, + "create_all_registration_tokens", + ) + ) + + # Check creating a single character token fails with a 500 status code + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 1}, + access_token=self.admin_user_tok, + ) + self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"]) + + def test_create_uses_allowed(self): + """Check you can only create a token with good values for uses_allowed.""" + # Should work with 0 (token is invalid from the start) + channel = self.make_request( + "POST", + self.url + "/new", + {"uses_allowed": 0}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 0) + + # Should fail with negative integer + channel = self.make_request( + "POST", + self.url + "/new", + {"uses_allowed": -5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with float + channel = self.make_request( + "POST", + self.url + "/new", + {"uses_allowed": 1.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_expiry_time(self): + """Check you can't create a token with an invalid expiry_time.""" + # Should fail with a time in the past + channel = self.make_request( + "POST", + self.url + "/new", + {"expiry_time": self.clock.time_msec() - 10000}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with float + channel = self.make_request( + "POST", + self.url + "/new", + {"expiry_time": self.clock.time_msec() + 1000000.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_length(self): + """Check you can only generate a token with a valid length.""" + # Should work with 64 + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 64}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["token"]), 64) + + # Should fail with 0 + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 0}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with a negative integer + channel = self.make_request( + "POST", + self.url + "/new", + {"length": -5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with a float + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 8.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with 65 + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 65}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # UPDATING + + def test_update_no_auth(self): + """Try to update a token without authentication.""" + channel = self.make_request( + "PUT", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_update_requester_not_admin(self): + """Try to update a token while not an admin.""" + channel = self.make_request( + "PUT", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_update_non_existent(self): + """Try to update a token that doesn't exist.""" + channel = self.make_request( + "PUT", + self.url + "/1234", + {"uses_allowed": 1}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_update_uses_allowed(self): + """Test updating just uses_allowed.""" + # Create new token using default values + token = self._new_token() + + # Should succeed with 1 + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": 1}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 1) + self.assertIsNone(channel.json_body["expiry_time"]) + + # Should succeed with 0 (makes token invalid) + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": 0}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 0) + self.assertIsNone(channel.json_body["expiry_time"]) + + # Should succeed with null + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": None}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + + # Should fail with a float + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": 1.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with a negative integer + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": -5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_update_expiry_time(self): + """Test updating just expiry_time.""" + # Create new token using default values + token = self._new_token() + new_expiry_time = self.clock.time_msec() + 1000000 + + # Should succeed with a time in the future + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"expiry_time": new_expiry_time}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) + self.assertIsNone(channel.json_body["uses_allowed"]) + + # Should succeed with null + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"expiry_time": None}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertIsNone(channel.json_body["uses_allowed"]) + + # Should fail with a time in the past + past_time = self.clock.time_msec() - 10000 + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"expiry_time": past_time}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail a float + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"expiry_time": new_expiry_time + 0.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_update_both(self): + """Test updating both uses_allowed and expiry_time.""" + # Create new token using default values + token = self._new_token() + new_expiry_time = self.clock.time_msec() + 1000000 + + data = { + "uses_allowed": 1, + "expiry_time": new_expiry_time, + } + + channel = self.make_request( + "PUT", + self.url + "/" + token, + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 1) + self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) + + def test_update_invalid_type(self): + """Test using invalid types doesn't work.""" + # Create new token using default values + token = self._new_token() + + data = { + "uses_allowed": False, + "expiry_time": "1626430124000", + } + + channel = self.make_request( + "PUT", + self.url + "/" + token, + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # DELETING + + def test_delete_no_auth(self): + """Try to delete a token without authentication.""" + channel = self.make_request( + "DELETE", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_delete_requester_not_admin(self): + """Try to delete a token while not an admin.""" + channel = self.make_request( + "DELETE", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_delete_non_existent(self): + """Try to delete a token that doesn't exist.""" + channel = self.make_request( + "DELETE", + self.url + "/1234", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_delete(self): + """Test deleting a token.""" + # Create new token using default values + token = self._new_token() + + channel = self.make_request( + "DELETE", + self.url + "/" + token, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # GETTING ONE + + def test_get_no_auth(self): + """Try to get a token without authentication.""" + channel = self.make_request( + "GET", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_get_requester_not_admin(self): + """Try to get a token while not an admin.""" + channel = self.make_request( + "GET", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_get_non_existent(self): + """Try to get a token that doesn't exist.""" + channel = self.make_request( + "GET", + self.url + "/1234", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_get(self): + """Test getting a token.""" + # Create new token using default values + token = self._new_token() + + channel = self.make_request( + "GET", + self.url + "/" + token, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["token"], token) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + # LISTING + + def test_list_no_auth(self): + """Try to list tokens without authentication.""" + channel = self.make_request("GET", self.url, {}) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_list_requester_not_admin(self): + """Try to list tokens while not an admin.""" + channel = self.make_request( + "GET", + self.url, + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_list_all(self): + """Test listing all tokens.""" + # Create new token using default values + token = self._new_token() + + channel = self.make_request( + "GET", + self.url, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["registration_tokens"]), 1) + token_info = channel.json_body["registration_tokens"][0] + self.assertEqual(token_info["token"], token) + self.assertIsNone(token_info["uses_allowed"]) + self.assertIsNone(token_info["expiry_time"]) + self.assertEqual(token_info["pending"], 0) + self.assertEqual(token_info["completed"], 0) + + def test_list_invalid_query_parameter(self): + """Test with `valid` query parameter not `true` or `false`.""" + channel = self.make_request( + "GET", + self.url + "?valid=x", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + + def _test_list_query_parameter(self, valid: str): + """Helper used to test both valid=true and valid=false.""" + # Create 2 valid and 2 invalid tokens. + now = self.hs.get_clock().time_msec() + # Create always valid token + valid1 = self._new_token() + # Create token that hasn't been used up + valid2 = self._new_token(uses_allowed=1) + # Create token that has expired + invalid1 = self._new_token(expiry_time=now - 10000) + # Create token that has been used up but hasn't expired + invalid2 = self._new_token( + uses_allowed=2, + pending=1, + completed=1, + expiry_time=now + 1000000, + ) + + if valid == "true": + tokens = [valid1, valid2] + else: + tokens = [invalid1, invalid2] + + channel = self.make_request( + "GET", + self.url + "?valid=" + valid, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["registration_tokens"]), 2) + token_info_1 = channel.json_body["registration_tokens"][0] + token_info_2 = channel.json_body["registration_tokens"][1] + self.assertIn(token_info_1["token"], tokens) + self.assertIn(token_info_2["token"], tokens) + + def test_list_valid(self): + """Test listing just valid tokens.""" + self._test_list_query_parameter(valid="true") + + def test_list_invalid(self): + """Test listing just invalid tokens.""" + self._test_list_query_parameter(valid="false") diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index fecda037a5..9f3ab2c985 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client import account, account_validity, login, logout, register, sync +from synapse.storage._base import db_to_json from tests import unittest from tests.unittest import override_config @@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config({"registration_requires_token": True}) + def test_POST_registration_requires_token(self): + username = "kermit" + device_id = "frogfone" + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + params = { + "username": username, + "password": "monkey", + "device_id": device_id, + } + + # Request without auth to get flows and session + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + # Synapse adds a dummy stage to differentiate flows where otherwise one + # flow would be a subset of another flow. + self.assertCountEqual( + [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]], + (f["stages"] for f in flows), + ) + session = channel.json_body["session"] + + # Do the registration token stage and check it has completed + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + self.assertEquals(channel.result["code"], b"401", channel.result) + completed = channel.json_body["completed"] + self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) + + # Do the m.login.dummy stage and check registration was successful + params["auth"] = { + "type": LoginType.DUMMY, + "session": session, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + det_data = { + "user_id": f"@{username}:{self.hs.hostname}", + "home_server": self.hs.hostname, + "device_id": device_id, + } + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, channel.json_body) + + # Check the `completed` counter has been incremented and pending is 0 + res = self.get_success( + store.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ) + ) + self.assertEquals(res["completed"], 1) + self.assertEquals(res["pending"], 0) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_invalid(self): + params = { + "username": "kermit", + "password": "monkey", + } + # Request without auth to get session + channel = self.make_request(b"POST", self.url, json.dumps(params)) + session = channel.json_body["session"] + + # Test with token param missing (invalid) + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "session": session, + } + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM) + self.assertEquals(channel.json_body["completed"], []) + + # Test with non-string (invalid) + params["auth"]["token"] = 1234 + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM) + self.assertEquals(channel.json_body["completed"], []) + + # Test with unknown token (invalid) + params["auth"]["token"] = "1234" + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_limit_uses(self): + token = "abcd" + store = self.hs.get_datastore() + # Create token that can be used once + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": 1, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + params1 = {"username": "bert", "password": "monkey"} + params2 = {"username": "ernie", "password": "monkey"} + # Do 2 requests without auth to get two session IDs + channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + session1 = channel1.json_body["session"] + channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) + session2 = channel2.json_body["session"] + + # Use token with session1 and check `pending` is 1 + params1["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session1, + } + self.make_request(b"POST", self.url, json.dumps(params1)) + # Repeat request to make sure pending isn't increased again + self.make_request(b"POST", self.url, json.dumps(params1)) + pending = self.get_success( + store.db_pool.simple_select_one_onecol( + "registration_tokens", + keyvalues={"token": token}, + retcol="pending", + ) + ) + self.assertEquals(pending, 1) + + # Check auth fails when using token with session2 + params2["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session2, + } + channel = self.make_request(b"POST", self.url, json.dumps(params2)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + # Complete registration with session1 + params1["auth"]["type"] = LoginType.DUMMY + self.make_request(b"POST", self.url, json.dumps(params1)) + # Check pending=0 and completed=1 + res = self.get_success( + store.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ) + ) + self.assertEquals(res["pending"], 0) + self.assertEquals(res["completed"], 1) + + # Check auth still fails when using token with session2 + channel = self.make_request(b"POST", self.url, json.dumps(params2)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_expiry(self): + token = "abcd" + now = self.hs.get_clock().time_msec() + store = self.hs.get_datastore() + # Create token that expired yesterday + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": now - 24 * 60 * 60 * 1000, + }, + ) + ) + params = {"username": "kermit", "password": "monkey"} + # Request without auth to get session + channel = self.make_request(b"POST", self.url, json.dumps(params)) + session = channel.json_body["session"] + + # Check authentication fails with expired token + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session, + } + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + # Update token so it expires tomorrow + self.get_success( + store.db_pool.simple_update_one( + "registration_tokens", + keyvalues={"token": token}, + updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000}, + ) + ) + + # Check authentication succeeds + channel = self.make_request(b"POST", self.url, json.dumps(params)) + completed = channel.json_body["completed"] + self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_session_expiry(self): + """Test `pending` is decremented when an uncompleted session expires.""" + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + + # Do 2 requests without auth to get two session IDs + params1 = {"username": "bert", "password": "monkey"} + params2 = {"username": "ernie", "password": "monkey"} + channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + session1 = channel1.json_body["session"] + channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) + session2 = channel2.json_body["session"] + + # Use token with both sessions + params1["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session1, + } + self.make_request(b"POST", self.url, json.dumps(params1)) + + params2["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session2, + } + self.make_request(b"POST", self.url, json.dumps(params2)) + + # Complete registration with session1 + params1["auth"]["type"] = LoginType.DUMMY + self.make_request(b"POST", self.url, json.dumps(params1)) + + # Check `result` of registration token stage for session1 is `True` + result1 = self.get_success( + store.db_pool.simple_select_one_onecol( + "ui_auth_sessions_credentials", + keyvalues={ + "session_id": session1, + "stage_type": LoginType.REGISTRATION_TOKEN, + }, + retcol="result", + ) + ) + self.assertTrue(db_to_json(result1)) + + # Check `result` for session2 is the token used + result2 = self.get_success( + store.db_pool.simple_select_one_onecol( + "ui_auth_sessions_credentials", + keyvalues={ + "session_id": session2, + "stage_type": LoginType.REGISTRATION_TOKEN, + }, + retcol="result", + ) + ) + self.assertEquals(db_to_json(result2), token) + + # Delete both sessions (mimics expiry) + self.get_success( + store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec()) + ) + + # Check pending is now 0 + pending = self.get_success( + store.db_pool.simple_select_one_onecol( + "registration_tokens", + keyvalues={"token": token}, + retcol="pending", + ) + ) + self.assertEquals(pending, 0) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_session_expiry_deleted_token(self): + """Test session expiry doesn't break when the token is deleted. + + 1. Start but don't complete UIA with a registration token + 2. Delete the token from the database + 3. Expire the session + """ + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + + # Do request without auth to get a session ID + params = {"username": "kermit", "password": "monkey"} + channel = self.make_request(b"POST", self.url, json.dumps(params)) + session = channel.json_body["session"] + + # Use token + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session, + } + self.make_request(b"POST", self.url, json.dumps(params)) + + # Delete token + self.get_success( + store.db_pool.simple_delete_one( + "registration_tokens", + keyvalues={"token": token}, + ) + ) + + # Delete session (mimics expiry) + self.get_success( + store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec()) + ) + def test_advertised_flows(self): channel = self.make_request(b"POST", self.url, b"{}") self.assertEquals(channel.result["code"], b"401", channel.result) @@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) self.assertLessEqual(res, now_ms + self.validity_period) + + +class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): + servlets = [register.register_servlets] + url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity" + + def default_config(self): + config = super().default_config() + config["registration_requires_token"] = True + return config + + def test_GET_token_valid(self): + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEquals(channel.json_body["valid"], True) + + def test_GET_token_invalid(self): + token = "1234" + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEquals(channel.json_body["valid"], False) + + @override_config( + {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}} + ) + def test_GET_ratelimiting(self): + token = "1234" + + for i in range(0, 6): + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + self.assertEquals(channel.result["code"], b"200", channel.result) -- cgit 1.5.1