From 42d3a28d8bb8c08e9e0d00a2e247cbbddb1a155c Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 11 Jan 2021 17:15:54 +0100 Subject: Removes unnecessary declarations in the tests for the admin API. (#9063) --- tests/rest/admin/test_admin.py | 3 --- tests/rest/admin/test_event_reports.py | 4 ---- tests/rest/admin/test_media.py | 2 -- tests/rest/admin/test_room.py | 2 -- tests/rest/admin/test_statistics.py | 1 - tests/rest/admin/test_user.py | 5 ----- 6 files changed, 17 deletions(-) (limited to 'tests/rest') diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 0504cd187e..586b877bda 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -58,8 +58,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -156,7 +154,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.hs = hs # Allow for uploading and downloading to/from the media repo self.media_repo = hs.get_media_repository_resource() diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index aa389df12f..d0090faa4f 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -32,8 +32,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -371,8 +369,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index c2b998cdae..51a7731693 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -35,7 +35,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.handler = hs.get_device_handler() self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -181,7 +180,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.handler = hs.get_device_handler() self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index fa620f97f3..a0f32c5512 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -605,8 +605,6 @@ class RoomTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - # Create user self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 73f8a8ec99..f48be3d65a 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -31,7 +31,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() self.media_repo = hs.get_media_repository_resource() self.admin_user = self.register_user("admin", "pass", admin=True) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 9b2e4765f6..877fd2587b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1204,8 +1204,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -1401,7 +1399,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() self.media_repo = hs.get_media_repository_resource() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -1868,8 +1865,6 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") -- cgit 1.5.1 From b161528fccac4bf17f7afa9438a75d796433194e Mon Sep 17 00:00:00 2001 From: David Teller Date: Mon, 11 Jan 2021 20:32:17 +0100 Subject: Also support remote users on the joined_rooms admin API. (#8948) For remote users, only the rooms which the server knows about are returned. Local users have all of their joined rooms returned. --- changelog.d/8948.feature | 1 + docs/admin_api/user_admin_api.rst | 4 +++ synapse/rest/admin/users.py | 7 ----- tests/rest/admin/test_user.py | 58 +++++++++++++++++++++++++++++++++++---- 4 files changed, 57 insertions(+), 13 deletions(-) create mode 100644 changelog.d/8948.feature (limited to 'tests/rest') diff --git a/changelog.d/8948.feature b/changelog.d/8948.feature new file mode 100644 index 0000000000..3b06cbfa22 --- /dev/null +++ b/changelog.d/8948.feature @@ -0,0 +1 @@ +Update `/_synapse/admin/v1/users//joined_rooms` to work for both local and remote users. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index e4d6f8203b..3115951e1f 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -337,6 +337,10 @@ A response body like the following is returned: "total": 2 } +The server returns the list of rooms of which the user and the server +are member. If the user is local, all the rooms of which the user is +member are returned. + **Parameters** The following parameters should be set in the URL: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 6658c2da56..f8a73e7d9d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -714,13 +714,6 @@ class UserMembershipRestServlet(RestServlet): async def on_GET(self, request, user_id): await assert_requester_is_admin(self.auth, request) - if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only lookup local users") - - user = await self.store.get_user_by_id(user_id) - if user is None: - raise NotFoundError("Unknown user") - room_ids = await self.store.get_rooms_for_user(user_id) ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} return 200, ret diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 877fd2587b..ad4588c1da 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -25,6 +25,7 @@ from mock import Mock import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError +from synapse.api.room_versions import RoomVersions from synapse.rest.client.v1 import login, logout, profile, room from synapse.rest.client.v2_alpha import devices, sync @@ -1234,24 +1235,26 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns an empty list """ url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms" channel = self.make_request("GET", url, access_token=self.admin_user_tok,) - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["joined_rooms"])) def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local and participates in no conversation returns an empty list """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms" channel = self.make_request("GET", url, access_token=self.admin_user_tok,) - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["joined_rooms"])) def test_no_memberships(self): """ @@ -1282,6 +1285,49 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.assertEqual(number_rooms, channel.json_body["total"]) self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) + def test_get_rooms_with_nonlocal_user(self): + """ + Tests that a normal lookup for rooms is successful with a non-local user + """ + + other_user_tok = self.login("user", "pass") + event_builder_factory = self.hs.get_event_builder_factory() + event_creation_handler = self.hs.get_event_creation_handler() + storage = self.hs.get_storage() + + # Create two rooms, one with a local user only and one with both a local + # and remote user. + self.helper.create_room_as(self.other_user, tok=other_user_tok) + local_and_remote_room_id = self.helper.create_room_as( + self.other_user, tok=other_user_tok + ) + + # Add a remote user to the room. + builder = event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": "m.room.member", + "sender": "@joiner:remote_hs", + "state_key": "@joiner:remote_hs", + "room_id": local_and_remote_room_id, + "content": {"membership": "join"}, + }, + ) + + event, context = self.get_success( + event_creation_handler.create_new_client_event(builder) + ) + + self.get_success(storage.persistence.persist_event(event, context)) + + # Now get rooms + url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"]) + class PushersRestTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From 7a2e9b549defe3f55531711a863183a33e7af83c Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 12 Jan 2021 22:30:15 +0100 Subject: Remove user's avatar URL and displayname when deactivated. (#8932) This only applies if the user's data is to be erased. --- changelog.d/8932.feature | 1 + docs/admin_api/user_admin_api.rst | 21 +++ synapse/handlers/deactivate_account.py | 18 ++- synapse/handlers/profile.py | 8 +- synapse/rest/admin/users.py | 22 ++- synapse/rest/client/v2_alpha/account.py | 7 +- synapse/server.py | 2 +- synapse/storage/databases/main/profile.py | 2 +- tests/handlers/test_profile.py | 30 ++++ tests/rest/admin/test_user.py | 220 ++++++++++++++++++++++++++++++ tests/rest/client/v1/test_login.py | 5 +- tests/rest/client/v1/test_rooms.py | 6 +- tests/storage/test_profile.py | 26 ++++ 13 files changed, 351 insertions(+), 17 deletions(-) create mode 100644 changelog.d/8932.feature (limited to 'tests/rest') diff --git a/changelog.d/8932.feature b/changelog.d/8932.feature new file mode 100644 index 0000000000..a1d17394d7 --- /dev/null +++ b/changelog.d/8932.feature @@ -0,0 +1 @@ +Remove a user's avatar URL and display name when deactivated with the Admin API. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index 3115951e1f..b3d413cf57 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -98,6 +98,8 @@ Body parameters: - ``deactivated``, optional. If unspecified, deactivation state will be left unchanged on existing accounts and set to ``false`` for new accounts. + A user cannot be erased by deactivating with this API. For details on deactivating users see + `Deactivate Account <#deactivate-account>`_. If the user already exists then optional parameters default to the current value. @@ -248,6 +250,25 @@ server admin: see `README.rst `_. The erase parameter is optional and defaults to ``false``. An empty body may be passed for backwards compatibility. +The following actions are performed when deactivating an user: + +- Try to unpind 3PIDs from the identity server +- Remove all 3PIDs from the homeserver +- Delete all devices and E2EE keys +- Delete all access tokens +- Delete the password hash +- Removal from all rooms the user is a member of +- Remove the user from the user directory +- Reject all pending invites +- Remove all account validity information related to the user + +The following additional actions are performed during deactivation if``erase`` +is set to ``true``: + +- Remove the user's display name +- Remove the user's avatar URL +- Mark the user as erased + Reset password ============== diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index e808142365..c4a3b26a84 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import UserID, create_requester +from synapse.types import Requester, UserID, create_requester from ._base import BaseHandler @@ -38,6 +38,7 @@ class DeactivateAccountHandler(BaseHandler): self._device_handler = hs.get_device_handler() self._room_member_handler = hs.get_room_member_handler() self._identity_handler = hs.get_identity_handler() + self._profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() self._server_name = hs.hostname @@ -52,16 +53,23 @@ class DeactivateAccountHandler(BaseHandler): self._account_validity_enabled = hs.config.account_validity.enabled async def deactivate_account( - self, user_id: str, erase_data: bool, id_server: Optional[str] = None + self, + user_id: str, + erase_data: bool, + requester: Requester, + id_server: Optional[str] = None, + by_admin: bool = False, ) -> bool: """Deactivate a user's account Args: user_id: ID of user to be deactivated erase_data: whether to GDPR-erase the user's data + requester: The user attempting to make this change. id_server: Use the given identity server when unbinding any threepids. If None then will attempt to unbind using the identity server specified when binding (if known). + by_admin: Whether this change was made by an administrator. Returns: True if identity server supports removing threepids, otherwise False. @@ -121,6 +129,12 @@ class DeactivateAccountHandler(BaseHandler): # Mark the user as erased, if they asked for that if erase_data: + user = UserID.from_string(user_id) + # Remove avatar URL from this user + await self._profile_handler.set_avatar_url(user, requester, "", by_admin) + # Remove displayname from this user + await self._profile_handler.set_displayname(user, requester, "", by_admin) + logger.info("Marking %s as erased", user_id) await self.store.mark_user_erased(user_id) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 36f9ee4b71..c02b951031 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -286,13 +286,19 @@ class ProfileHandler(BaseHandler): 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) ) + avatar_url_to_set = new_avatar_url # type: Optional[str] + if new_avatar_url == "": + avatar_url_to_set = None + # Same like set_displayname if by_admin: requester = create_requester( target_user, authenticated_entity=requester.authenticated_entity ) - await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) + await self.store.set_profile_avatar_url( + target_user.localpart, avatar_url_to_set + ) if self.hs.config.user_directory_search_all_users: profile = await self.store.get_profileinfo(target_user.localpart) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index f8a73e7d9d..f39e3d6d5c 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -244,7 +244,7 @@ class UserRestServletV2(RestServlet): if deactivate and not user["deactivated"]: await self.deactivate_account_handler.deactivate_account( - target_user.to_string(), False + target_user.to_string(), False, requester, by_admin=True ) elif not deactivate and user["deactivated"]: if "password" not in body: @@ -486,12 +486,22 @@ class WhoisRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet): PATTERNS = admin_patterns("/deactivate/(?P[^/]*)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self._deactivate_account_handler = hs.get_deactivate_account_handler() self.auth = hs.get_auth() + self.is_mine = hs.is_mine + self.store = hs.get_datastore() + + async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + if not self.is_mine(UserID.from_string(target_user_id)): + raise SynapseError(400, "Can only deactivate local users") + + if not await self.store.get_user_by_id(target_user_id): + raise NotFoundError("User not found") - async def on_POST(self, request, target_user_id): - await assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request, allow_empty_body=True) erase = body.get("erase", False) if not isinstance(erase, bool): @@ -501,10 +511,8 @@ class DeactivateAccountRestServlet(RestServlet): Codes.BAD_JSON, ) - UserID.from_string(target_user_id) - result = await self._deactivate_account_handler.deactivate_account( - target_user_id, erase + target_user_id, erase, requester, by_admin=True ) if result: id_server_unbind_result = "success" diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 3b50dc885f..65e68d641b 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -305,7 +305,7 @@ class DeactivateAccountRestServlet(RestServlet): # allow ASes to deactivate their own users if requester.app_service: await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase + requester.user.to_string(), erase, requester ) return 200, {} @@ -313,7 +313,10 @@ class DeactivateAccountRestServlet(RestServlet): requester, request, body, "deactivate your account", ) result = await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase, id_server=body.get("id_server") + requester.user.to_string(), + erase, + requester, + id_server=body.get("id_server"), ) if result: id_server_unbind_result = "success" diff --git a/synapse/server.py b/synapse/server.py index 12da92b63c..d4c235cda5 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -501,7 +501,7 @@ class HomeServer(metaclass=abc.ABCMeta): return InitialSyncHandler(self) @cache_in_self - def get_profile_handler(self): + def get_profile_handler(self) -> ProfileHandler: return ProfileHandler(self) @cache_in_self diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 0e25ca3d7a..54ef0f1f54 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -82,7 +82,7 @@ class ProfileWorkerStore(SQLBaseStore): ) async def set_profile_avatar_url( - self, user_localpart: str, new_avatar_url: str + self, user_localpart: str, new_avatar_url: Optional[str] ) -> None: await self.db_pool.simple_update_one( table="profiles", diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 919547556b..022943a10a 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -105,6 +105,21 @@ class ProfileTestCase(unittest.TestCase): "Frank", ) + # Set displayname to an empty string + yield defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "" + ) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ) + ) + @defer.inlineCallbacks def test_set_my_name_if_disabled(self): self.hs.config.enable_set_displayname = False @@ -223,6 +238,21 @@ class ProfileTestCase(unittest.TestCase): "http://my.server/me.png", ) + # Set avatar to an empty string + yield defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, synapse.types.create_requester(self.frank), "", + ) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), + ) + @defer.inlineCallbacks def test_set_my_avatar_if_disabled(self): self.hs.config.enable_set_avatar_url = False diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ad4588c1da..04599c2fcf 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -588,6 +588,200 @@ class UsersListTestCase(unittest.HomeserverTestCase): _search_test(None, "bar", "user_id") +class DeactivateAccountTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass", displayname="User1") + self.other_user_token = self.login("user", "pass") + self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( + self.other_user + ) + self.url = "/_synapse/admin/v1/deactivate/%s" % urllib.parse.quote( + self.other_user + ) + + # set attributes for user + self.get_success( + self.store.set_profile_avatar_url("user", "mxc://servername/mediaid") + ) + self.get_success( + self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) + ) + + def test_no_auth(self): + """ + Try to deactivate users without authentication. + """ + channel = self.make_request("POST", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_not_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + url = "/_synapse/admin/v1/deactivate/@bob:test" + + channel = self.make_request("POST", url, access_token=self.other_user_token) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("You are not a server admin", channel.json_body["error"]) + + channel = self.make_request( + "POST", url, access_token=self.other_user_token, content=b"{}", + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("You are not a server admin", channel.json_body["error"]) + + def test_user_does_not_exist(self): + """ + Tests that deactivation for a user that does not exist returns a 404 + """ + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/deactivate/@unknown_person:test", + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_erase_is_not_bool(self): + """ + If parameter `erase` is not boolean, return an error + """ + body = json.dumps({"erase": "False"}) + + channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that deactivation for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain" + + channel = self.make_request("POST", url, access_token=self.admin_user_tok) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only deactivate local users", channel.json_body["error"]) + + def test_deactivate_user_erase_true(self): + """ + Test deactivating an user and set `erase` to `true` + """ + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User1", channel.json_body["displayname"]) + + # Deactivate user + body = json.dumps({"erase": True}) + + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertIsNone(channel.json_body["avatar_url"]) + self.assertIsNone(channel.json_body["displayname"]) + + self._is_erased("@user:test", True) + + def test_deactivate_user_erase_false(self): + """ + Test deactivating an user and set `erase` to `false` + """ + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User1", channel.json_body["displayname"]) + + # Deactivate user + body = json.dumps({"erase": False}) + + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User1", channel.json_body["displayname"]) + + self._is_erased("@user:test", False) + + def _is_erased(self, user_id: str, expect: bool) -> None: + """Assert that the user is erased or not + """ + d = self.store.is_user_erased(user_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertFalse(self.get_success(d)) + + class UserRestTestCase(unittest.HomeserverTestCase): servlets = [ @@ -987,6 +1181,26 @@ class UserRestTestCase(unittest.HomeserverTestCase): Test deactivating another user. """ + # set attributes for user + self.get_success( + self.store.set_profile_avatar_url("user", "mxc://servername/mediaid") + ) + self.get_success( + self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) + ) + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User", channel.json_body["displayname"]) + # Deactivate user body = json.dumps({"deactivated": True}) @@ -1000,6 +1214,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User", channel.json_body["displayname"]) # the user is deactivated, the threepid will be deleted # Get user @@ -1010,6 +1227,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User", channel.json_body["displayname"]) @override_config({"user_directory": {"enabled": True, "search_all_users": True}}) def test_change_name_deactivate_user_user_directory(self): diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 1d1dc9f8a2..f9b8011961 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -30,6 +30,7 @@ from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from synapse.rest.synapse.client.pick_idp import PickIdpResource +from synapse.types import create_requester from tests import unittest from tests.handlers.test_oidc import HAS_OIDC @@ -667,7 +668,9 @@ class CASTestCase(unittest.HomeserverTestCase): # Deactivate the account. self.get_success( - self.deactivate_account_handler.deactivate_account(self.user_id, False) + self.deactivate_account_handler.deactivate_account( + self.user_id, False, create_requester(self.user_id) + ) ) # Request the CAS ticket. diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 6105eac47c..d4e3165436 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -29,7 +29,7 @@ from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import account -from synapse.types import JsonDict, RoomAlias, UserID +from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util.stringutils import random_string from tests import unittest @@ -1687,7 +1687,9 @@ class ContextTestCase(unittest.HomeserverTestCase): deactivate_account_handler = self.hs.get_deactivate_account_handler() self.get_success( - deactivate_account_handler.deactivate_account(self.user_id, erase_data=True) + deactivate_account_handler.deactivate_account( + self.user_id, True, create_requester(self.user_id) + ) ) # Invite another user in the room. This is needed because messages will be diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 3fd0a38cf5..ea63bd56b4 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -48,6 +48,19 @@ class ProfileStoreTestCase(unittest.TestCase): ), ) + # test set to None + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.u_frank.localpart, None) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.u_frank.localpart) + ) + ) + ) + @defer.inlineCallbacks def test_avatar_url(self): yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) @@ -66,3 +79,16 @@ class ProfileStoreTestCase(unittest.TestCase): ) ), ) + + # test set to None + yield defer.ensureDeferred( + self.store.set_profile_avatar_url(self.u_frank.localpart, None) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.u_frank.localpart) + ) + ) + ) -- cgit 1.5.1 From 233c8b9fce99616631614d176de4612c29277884 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 13 Jan 2021 20:21:55 +0000 Subject: Add a test for UI-Auth-via-SSO (#9082) * Add complete test for UI-Auth-via-SSO. * review comments --- changelog.d/9082.feature | 1 + tests/rest/client/v1/utils.py | 196 ++++++++++++++++++++++++++------ tests/rest/client/v2_alpha/test_auth.py | 40 ++++++- tests/server.py | 32 +++++- 4 files changed, 227 insertions(+), 42 deletions(-) create mode 100644 changelog.d/9082.feature (limited to 'tests/rest') diff --git a/changelog.d/9082.feature b/changelog.d/9082.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9082.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 81b7f84360..85d1709ead 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -2,7 +2,7 @@ # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd # Copyright 2018-2019 New Vector Ltd -# Copyright 2019-2020 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,7 +20,8 @@ import json import re import time import urllib.parse -from typing import Any, Dict, Optional +from html.parser import HTMLParser +from typing import Any, Dict, Iterable, List, MutableMapping, Optional, Tuple from mock import patch @@ -32,7 +33,7 @@ from twisted.web.server import Site from synapse.api.constants import Membership from synapse.types import JsonDict -from tests.server import FakeSite, make_request +from tests.server import FakeChannel, FakeSite, make_request from tests.test_utils import FakeResponse @@ -362,34 +363,94 @@ class RestHelper: the normal places. """ client_redirect_url = "https://x" + channel = self.auth_via_oidc(remote_user_id, client_redirect_url) - # first hit the redirect url (which will issue a cookie and state) + # expect a confirmation page + assert channel.code == 200 + + # fish the matrix login token out of the body of the confirmation page + m = re.search( + 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), + channel.text_body, + ) + assert m, channel.text_body + login_token = m.group(1) + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token and device id. channel = make_request( self.hs.get_reactor(), self.site, - "GET", - "/login/sso/redirect?redirectUrl=" + client_redirect_url, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, ) - # that will redirect to the OIDC IdP, but we skip that and go straight + assert channel.code == 200 + return channel.json_body + + def auth_via_oidc( + self, + remote_user_id: str, + client_redirect_url: Optional[str] = None, + ui_auth_session_id: Optional[str] = None, + ) -> FakeChannel: + """Perform an OIDC authentication flow via a mock OIDC provider. + + This can be used for either login or user-interactive auth. + + Starts by making a request to the relevant synapse redirect endpoint, which is + expected to serve a 302 to the OIDC provider. We then make a request to the + OIDC callback endpoint, intercepting the HTTP requests that will get sent back + to the OIDC provider. + + Requires that "oidc_config" in the homeserver config be set appropriately + (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a + "public_base_url". + + Also requires the login servlet and the OIDC callback resource to be mounted at + the normal places. + + Args: + remote_user_id: the remote id that the OIDC provider should present + client_redirect_url: for a login flow, the client redirect URL to pass to + the login redirect endpoint + ui_auth_session_id: if set, we will perform a UI Auth flow. The session id + of the UI auth. + + Returns: + A FakeChannel containing the result of calling the OIDC callback endpoint. + Note that the response code may be a 200, 302 or 400 depending on how things + went. + """ + + cookies = {} + + # if we're doing a ui auth, hit the ui auth redirect endpoint + if ui_auth_session_id: + # can't set the client redirect url for UI Auth + assert client_redirect_url is None + oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) + else: + # otherwise, hit the login redirect endpoint + oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) + + # we now have a URI for the OIDC IdP, but we skip that and go straight # back to synapse's OIDC callback resource. However, we do need the "state" - # param that synapse passes to the IdP via query params, and the cookie that - # synapse passes to the client. - assert channel.code == 302 - oauth_uri = channel.headers.getRawHeaders("Location")[0] - params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query) - redirect_uri = "%s?%s" % ( + # param that synapse passes to the IdP via query params, as well as the cookie + # that synapse passes to the client. + + oauth_uri_path, oauth_uri_qs = oauth_uri.split("?", 1) + assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( + "unexpected SSO URI " + oauth_uri_path + ) + params = urllib.parse.parse_qs(oauth_uri_qs) + callback_uri = "%s?%s" % ( urllib.parse.urlparse(params["redirect_uri"][0]).path, urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}), ) - cookies = {} - for h in channel.headers.getRawHeaders("Set-Cookie"): - parts = h.split(";") - k, v = parts[0].split("=", maxsplit=1) - cookies[k] = v # before we hit the callback uri, stub out some methods in the http client so # that we don't have to handle full HTTPS requests. - # (expected url, json response) pairs, in the order we expect them. expected_requests = [ # first we get a hit to the token endpoint, which we tell to return @@ -413,34 +474,97 @@ class RestHelper: self.hs.get_reactor(), self.site, "GET", - redirect_uri, + callback_uri, custom_headers=[ ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items() ], ) + return channel - # expect a confirmation page - assert channel.code == 200 + def initiate_sso_login( + self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] + ) -> str: + """Make a request to the login-via-sso redirect endpoint, and return the target - # fish the matrix login token out of the body of the confirmation page - m = re.search( - 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), - channel.result["body"].decode("utf-8"), - ) - assert m - login_token = m.group(1) + Assumes that exactly one SSO provider has been configured. Requires the login + servlet to be mounted. - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token and device id. + Args: + client_redirect_url: the client redirect URL to pass to the login redirect + endpoint + cookies: any cookies returned will be added to this dict + + Returns: + the URI that the client gets redirected to (ie, the SSO server) + """ + params = {} + if client_redirect_url: + params["redirectUrl"] = client_redirect_url + + # hit the redirect url (which will issue a cookie and state) channel = make_request( self.hs.get_reactor(), self.site, - "POST", - "/login", - content={"type": "m.login.token", "token": login_token}, + "GET", + "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), ) - assert channel.code == 200 - return channel.json_body + + assert channel.code == 302 + channel.extract_cookies(cookies) + return channel.headers.getRawHeaders("Location")[0] + + def initiate_sso_ui_auth( + self, ui_auth_session_id: str, cookies: MutableMapping[str, str] + ) -> str: + """Make a request to the ui-auth-via-sso endpoint, and return the target + + Assumes that exactly one SSO provider has been configured. Requires the + AuthRestServlet to be mounted. + + Args: + ui_auth_session_id: the session id of the UI auth + cookies: any cookies returned will be added to this dict + + Returns: + the URI that the client gets linked to (ie, the SSO server) + """ + sso_redirect_endpoint = ( + "/_matrix/client/r0/auth/m.login.sso/fallback/web?" + + urllib.parse.urlencode({"session": ui_auth_session_id}) + ) + # hit the redirect url (which will issue a cookie and state) + channel = make_request( + self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint + ) + # that should serve a confirmation page + assert channel.code == 200, channel.text_body + channel.extract_cookies(cookies) + + # parse the confirmation page to fish out the link. + class ConfirmationPageParser(HTMLParser): + def __init__(self): + super().__init__() + + self.links = [] # type: List[str] + + def handle_starttag( + self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] + ) -> None: + attr_dict = dict(attrs) + if tag == "a": + href = attr_dict["href"] + if href: + self.links.append(href) + + def error(_, message): + raise AssertionError(message) + + p = ConfirmationPageParser() + p.feed(channel.text_body) + p.close() + assert len(p.links) == 1, "not exactly one link in confirmation page" + oauth_uri = p.links[0] + return oauth_uri # an 'oidc_config' suitable for login_via_oidc. diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index bb91e0c331..5f6ca23b06 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 New Vector +# Copyright 2020-2021 The Matrix.org Foundation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Union from twisted.internet.defer import succeed @@ -384,6 +384,44 @@ class UIAuthTests(unittest.HomeserverTestCase): # Note that *no auth* information is provided, not even a session iD! self.delete_device(self.user_tok, self.device_id, 200) + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) + def test_ui_auth_via_sso(self): + """Test a successful UI Auth flow via SSO + + This includes: + * hitting the UIA SSO redirect endpoint + * checking it serves a confirmation page which links to the OIDC provider + * calling back to the synapse oidc callback + * checking that the original operation succeeds + """ + + # log the user in + remote_user_id = UserID.from_string(self.user).localpart + login_resp = self.helper.login_via_oidc(remote_user_id) + self.assertEqual(login_resp["user_id"], self.user) + + # initiate a UI Auth process by attempting to delete the device + channel = self.delete_device(self.user_tok, self.device_id, 401) + + # check that SSO is offered + flows = channel.json_body["flows"] + self.assertIn({"stages": ["m.login.sso"]}, flows) + + # run the UIA-via-SSO flow + session_id = channel.json_body["session"] + channel = self.helper.auth_via_oidc( + remote_user_id=remote_user_id, ui_auth_session_id=session_id + ) + + # that should serve a confirmation page + self.assertEqual(channel.code, 200, channel.result) + + # and now the delete request should succeed. + self.delete_device( + self.user_tok, self.device_id, 200, body={"auth": {"session": session_id}}, + ) + @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_does_not_offer_password_for_sso_user(self): diff --git a/tests/server.py b/tests/server.py index 7d1ad362c4..5a1b66270f 100644 --- a/tests/server.py +++ b/tests/server.py @@ -2,7 +2,7 @@ import json import logging from collections import deque from io import SEEK_END, BytesIO -from typing import Callable, Iterable, Optional, Tuple, Union +from typing import Callable, Iterable, MutableMapping, Optional, Tuple, Union import attr from typing_extensions import Deque @@ -51,9 +51,21 @@ class FakeChannel: @property def json_body(self): - if not self.result: - raise Exception("No result yet.") - return json.loads(self.result["body"].decode("utf8")) + return json.loads(self.text_body) + + @property + def text_body(self) -> str: + """The body of the result, utf-8-decoded. + + Raises an exception if the request has not yet completed. + """ + if not self.is_finished: + raise Exception("Request not yet completed") + return self.result["body"].decode("utf8") + + def is_finished(self) -> bool: + """check if the response has been completely received""" + return self.result.get("done", False) @property def code(self): @@ -124,7 +136,7 @@ class FakeChannel: self._reactor.run() x = 0 - while not self.result.get("done"): + while not self.is_finished(): # If there's a producer, tell it to resume producing so we get content if self._producer: self._producer.resumeProducing() @@ -136,6 +148,16 @@ class FakeChannel: self._reactor.advance(0.1) + def extract_cookies(self, cookies: MutableMapping[str, str]) -> None: + """Process the contents of any Set-Cookie headers in the response + + Any cookines found are added to the given dict + """ + for h in self.headers.getRawHeaders("Set-Cookie"): + parts = h.split(";") + k, v = parts[0].split("=", maxsplit=1) + cookies[k] = v + class FakeSite: """ -- cgit 1.5.1 From 26d10331e5fa33507b680b45c44641e660a3adeb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 13 Jan 2021 12:35:23 +0000 Subject: Add a test for wrong user returned by SSO --- tests/rest/client/v2_alpha/test_auth.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) (limited to 'tests/rest') diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 5f6ca23b06..50630106ad 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -457,3 +457,30 @@ class UIAuthTests(unittest.HomeserverTestCase): self.assertIn({"stages": ["m.login.password"]}, flows) self.assertIn({"stages": ["m.login.sso"]}, flows) self.assertEqual(len(flows), 2) + + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) + def test_ui_auth_fails_for_incorrect_sso_user(self): + """If the user tries to authenticate with the wrong SSO user, they get an error + """ + # log the user in + login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + self.assertEqual(login_resp["user_id"], self.user) + + # start a UI Auth flow by attempting to delete a device + channel = self.delete_device(self.user_tok, self.device_id, 401) + + flows = channel.json_body["flows"] + self.assertIn({"stages": ["m.login.sso"]}, flows) + session_id = channel.json_body["session"] + + # do the OIDC auth, but auth as the wrong user + channel = self.helper.auth_via_oidc("wrong_user", ui_auth_session_id=session_id) + + # that should return a failure message + self.assertSubstring("We were unable to validate", channel.text_body) + + # ... and the delete op should now fail with a 403 + self.delete_device( + self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}} + ) -- cgit 1.5.1 From 0dd2649c127e4eb538dfbf0c879bd66c9ff1599c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 15 Jan 2021 13:45:13 +0000 Subject: Improve UsernamePickerTestCase (#9112) * make the OIDC bits of the test work at a higher level - via the REST api instead of poking the OIDCHandler directly. * Move it to test_login.py, where I think it fits better. --- changelog.d/9112.misc | 1 + tests/handlers/test_oidc.py | 120 +------------------------------- tests/rest/client/v1/test_login.py | 105 +++++++++++++++++++++++++++- tests/rest/client/v1/utils.py | 11 +-- tests/rest/client/v2_alpha/test_auth.py | 2 +- 5 files changed, 114 insertions(+), 125 deletions(-) create mode 100644 changelog.d/9112.misc (limited to 'tests/rest') diff --git a/changelog.d/9112.misc b/changelog.d/9112.misc new file mode 100644 index 0000000000..691f9d8b43 --- /dev/null +++ b/changelog.d/9112.misc @@ -0,0 +1 @@ +Improve `UsernamePickerTestCase`. diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 38ae8ca19e..02e21ed6ca 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,20 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -import re -from typing import Dict, Optional -from urllib.parse import parse_qs, urlencode, urlparse +from typing import Optional +from urllib.parse import parse_qs, urlparse from mock import ANY, Mock, patch import pymacaroons -from twisted.web.resource import Resource - -from synapse.api.errors import RedirectException from synapse.handlers.sso import MappingException -from synapse.rest.client.v1 import login -from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.server import HomeServer from synapse.types import UserID @@ -856,116 +850,6 @@ class OidcHandlerTestCase(HomeserverTestCase): ) -class UsernamePickerTestCase(HomeserverTestCase): - if not HAS_OIDC: - skip = "requires OIDC" - - servlets = [login.register_servlets] - - def default_config(self): - config = super().default_config() - config["public_baseurl"] = BASE_URL - oidc_config = { - "enabled": True, - "client_id": CLIENT_ID, - "client_secret": CLIENT_SECRET, - "issuer": ISSUER, - "scopes": SCOPES, - "user_mapping_provider": { - "config": {"display_name_template": "{{ user.displayname }}"} - }, - } - - # Update this config with what's in the default config so that - # override_config works as expected. - oidc_config.update(config.get("oidc_config", {})) - config["oidc_config"] = oidc_config - - # whitelist this client URI so we redirect straight to it rather than - # serving a confirmation page - config["sso"] = {"client_whitelist": ["https://whitelisted.client"]} - return config - - def create_resource_dict(self) -> Dict[str, Resource]: - d = super().create_resource_dict() - d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) - return d - - def test_username_picker(self): - """Test the happy path of a username picker flow.""" - client_redirect_url = "https://whitelisted.client" - - # first of all, mock up an OIDC callback to the OidcHandler, which should - # raise a RedirectException - userinfo = {"sub": "tester", "displayname": "Jonny"} - f = self.get_failure( - _make_callback_with_userinfo( - self.hs, userinfo, client_redirect_url=client_redirect_url - ), - RedirectException, - ) - - # check the Location and cookies returned by the RedirectException - self.assertEqual(f.value.location, b"/_synapse/client/pick_username") - cookieheader = f.value.cookies[0] - regex = re.compile(b"^username_mapping_session=([a-zA-Z]+);") - m = regex.search(cookieheader) - if not m: - self.fail("cookie header %s does not match %s" % (cookieheader, regex)) - - # introspect the sso handler a bit to check that the username mapping session - # looks ok. - session_id = m.group(1).decode("ascii") - username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions - self.assertIn( - session_id, username_mapping_sessions, "session id not found in map" - ) - session = username_mapping_sessions[session_id] - self.assertEqual(session.remote_user_id, "tester") - self.assertEqual(session.display_name, "Jonny") - self.assertEqual(session.client_redirect_url, client_redirect_url) - - # the expiry time should be about 15 minutes away - expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) - self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) - - # Now, submit a username to the username picker, which should serve a redirect - # back to the client - submit_path = f.value.location + b"/submit" - content = urlencode({b"username": b"bobby"}).encode("utf8") - chan = self.make_request( - "POST", - path=submit_path, - content=content, - content_is_form=True, - custom_headers=[ - ("Cookie", cookieheader), - # old versions of twisted don't do form-parsing without a valid - # content-length header. - ("Content-Length", str(len(content))), - ], - ) - self.assertEqual(chan.code, 302, chan.result) - location_headers = chan.headers.getRawHeaders("Location") - # ensure that the returned location starts with the requested redirect URL - self.assertEqual( - location_headers[0][: len(client_redirect_url)], client_redirect_url - ) - - # fish the login token out of the returned redirect uri - parts = urlparse(location_headers[0]) - query = parse_qs(parts.query) - login_token = query["loginToken"][0] - - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token, mxid, and device id. - chan = self.make_request( - "POST", "/login", content={"type": "m.login.token", "token": login_token}, - ) - self.assertEqual(chan.code, 200, chan.result) - self.assertEqual(chan.json_body["user_id"], "@bobby:test") - - async def _make_callback_with_userinfo( hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" ) -> None: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index f9b8011961..73a009efd1 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -17,6 +17,7 @@ import time import urllib.parse from html.parser import HTMLParser from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from urllib.parse import parse_qs, urlencode, urlparse from mock import Mock @@ -30,13 +31,14 @@ from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from synapse.rest.synapse.client.pick_idp import PickIdpResource +from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.types import create_requester from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG -from tests.unittest import override_config, skip_unless +from tests.unittest import HomeserverTestCase, override_config, skip_unless try: import jwt @@ -1060,3 +1062,104 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEquals(channel.result["code"], b"401", channel.result) + + +@skip_unless(HAS_OIDC, "requires OIDC") +class UsernamePickerTestCase(HomeserverTestCase): + """Tests for the username picker flow of SSO login""" + + servlets = [login.register_servlets] + + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + + config["oidc_config"] = {} + config["oidc_config"].update(TEST_OIDC_CONFIG) + config["oidc_config"]["user_mapping_provider"] = { + "config": {"display_name_template": "{{ user.displayname }}"} + } + + # whitelist this client URI so we redirect straight to it rather than + # serving a confirmation page + config["sso"] = {"client_whitelist": ["https://whitelisted.client"]} + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + from synapse.rest.oidc import OIDCResource + + d = super().create_resource_dict() + d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) + d["/_synapse/oidc"] = OIDCResource(self.hs) + return d + + def test_username_picker(self): + """Test the happy path of a username picker flow.""" + client_redirect_url = "https://whitelisted.client" + + # do the start of the login flow + channel = self.helper.auth_via_oidc( + {"sub": "tester", "displayname": "Jonny"}, client_redirect_url + ) + + # that should redirect to the username picker + self.assertEqual(channel.code, 302, channel.result) + picker_url = channel.headers.getRawHeaders("Location")[0] + self.assertEqual(picker_url, "/_synapse/client/pick_username") + + # ... with a username_mapping_session cookie + cookies = {} # type: Dict[str,str] + channel.extract_cookies(cookies) + self.assertIn("username_mapping_session", cookies) + session_id = cookies["username_mapping_session"] + + # introspect the sso handler a bit to check that the username mapping session + # looks ok. + username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions + self.assertIn( + session_id, username_mapping_sessions, "session id not found in map", + ) + session = username_mapping_sessions[session_id] + self.assertEqual(session.remote_user_id, "tester") + self.assertEqual(session.display_name, "Jonny") + self.assertEqual(session.client_redirect_url, client_redirect_url) + + # the expiry time should be about 15 minutes away + expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) + self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) + + # Now, submit a username to the username picker, which should serve a redirect + # back to the client + submit_path = picker_url + "/submit" + content = urlencode({b"username": b"bobby"}).encode("utf8") + chan = self.make_request( + "POST", + path=submit_path, + content=content, + content_is_form=True, + custom_headers=[ + ("Cookie", "username_mapping_session=" + session_id), + # old versions of twisted don't do form-parsing without a valid + # content-length header. + ("Content-Length", str(len(content))), + ], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + # ensure that the returned location starts with the requested redirect URL + self.assertEqual( + location_headers[0][: len(client_redirect_url)], client_redirect_url + ) + + # fish the login token out of the returned redirect uri + parts = urlparse(location_headers[0]) + query = parse_qs(parts.query) + login_token = query["loginToken"][0] + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + chan = self.make_request( + "POST", "/login", content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@bobby:test") diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 85d1709ead..c6647dbe08 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -363,10 +363,10 @@ class RestHelper: the normal places. """ client_redirect_url = "https://x" - channel = self.auth_via_oidc(remote_user_id, client_redirect_url) + channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) # expect a confirmation page - assert channel.code == 200 + assert channel.code == 200, channel.result # fish the matrix login token out of the body of the confirmation page m = re.search( @@ -390,7 +390,7 @@ class RestHelper: def auth_via_oidc( self, - remote_user_id: str, + user_info_dict: JsonDict, client_redirect_url: Optional[str] = None, ui_auth_session_id: Optional[str] = None, ) -> FakeChannel: @@ -411,7 +411,8 @@ class RestHelper: the normal places. Args: - remote_user_id: the remote id that the OIDC provider should present + user_info_dict: the remote userinfo that the OIDC provider should present. + Typically this should be '{"sub": ""}'. client_redirect_url: for a login flow, the client redirect URL to pass to the login redirect endpoint ui_auth_session_id: if set, we will perform a UI Auth flow. The session id @@ -457,7 +458,7 @@ class RestHelper: # a dummy OIDC access token ("https://issuer.test/token", {"access_token": "TEST"}), # and then one to the user_info endpoint, which returns our remote user id. - ("https://issuer.test/userinfo", {"sub": remote_user_id}), + ("https://issuer.test/userinfo", user_info_dict), ] async def mock_req(method: str, uri: str, data=None, headers=None): diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 50630106ad..3e8661f9b9 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -411,7 +411,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # run the UIA-via-SSO flow session_id = channel.json_body["session"] channel = self.helper.auth_via_oidc( - remote_user_id=remote_user_id, ui_auth_session_id=session_id + {"sub": remote_user_id}, ui_auth_session_id=session_id ) # that should serve a confirmation page -- cgit 1.5.1 From 3e4cdfe5d9b6152b9a0b7f7ad22623c2bf19f7ff Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 15 Jan 2021 11:18:09 -0500 Subject: Add an admin API endpoint to protect media. (#9086) Protecting media stops it from being quarantined when e.g. all media in a room is quarantined. This is useful for sticker packs and other media that is uploaded by server administrators, but used by many people. --- changelog.d/9086.feature | 1 + docs/admin_api/media_admin_api.md | 24 +++++++++++++++ synapse/rest/admin/media.py | 64 ++++++++++++++++++++++++++++++--------- tests/rest/admin/test_admin.py | 8 +++-- 4 files changed, 79 insertions(+), 18 deletions(-) create mode 100644 changelog.d/9086.feature (limited to 'tests/rest') diff --git a/changelog.d/9086.feature b/changelog.d/9086.feature new file mode 100644 index 0000000000..3e678e24d5 --- /dev/null +++ b/changelog.d/9086.feature @@ -0,0 +1 @@ +Add an admin API for protecting local media from quarantine. diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index dfb8c5d751..90faeaaef0 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -4,6 +4,7 @@ * [Quarantining media by ID](#quarantining-media-by-id) * [Quarantining media in a room](#quarantining-media-in-a-room) * [Quarantining all media of a user](#quarantining-all-media-of-a-user) + * [Protecting media from being quarantined](#protecting-media-from-being-quarantined) - [Delete local media](#delete-local-media) * [Delete a specific local media](#delete-a-specific-local-media) * [Delete local media by date or size](#delete-local-media-by-date-or-size) @@ -123,6 +124,29 @@ The following fields are returned in the JSON response body: * `num_quarantined`: integer - The number of media items successfully quarantined +## Protecting media from being quarantined + +This API protects a single piece of local media from being quarantined using the +above APIs. This is useful for sticker packs and other shared media which you do +not want to get quarantined, especially when +[quarantining media in a room](#quarantining-media-in-a-room). + +Request: + +``` +POST /_synapse/admin/v1/media/protect/ + +{} +``` + +Where `media_id` is in the form of `abcdefg12345...`. + +Response: + +```json +{} +``` + # Delete local media This API deletes the *local* media from the disk of your own server. This includes any local thumbnails and copies of media downloaded from diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index c82b4f87d6..8720b1401f 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -15,6 +15,9 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple + +from twisted.web.http import Request from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_boolean, parse_integer @@ -23,6 +26,10 @@ from synapse.rest.admin._base import ( assert_requester_is_admin, assert_user_is_admin, ) +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet): admin_patterns("/quarantine_media/(?P[^/]+)") ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request, room_id: str): + async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet): PATTERNS = admin_patterns("/user/(?P[^/]+)/media/quarantine") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request, user_id: str): + async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet): "/media/quarantine/(?P[^/]+)/(?P[^/]+)" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request, server_name: str, media_id: str): + async def on_POST( + self, request: Request, server_name: str, media_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet): return 200, {} +class ProtectMediaByID(RestServlet): + """Protect local media from being quarantined. + """ + + PATTERNS = admin_patterns("/media/protect/(?P[^/]+)") + + def __init__(self, hs: "HomeServer"): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + logging.info("Protecting local media by ID: %s", media_id) + + # Quarantine this media id + await self.store.mark_local_media_as_safe(media_id) + + return 200, {} + + class ListMediaInRoom(RestServlet): """Lists all of the media in a given room. """ PATTERNS = admin_patterns("/room/(?P[^/]+)/media") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_GET(self, request, room_id): + async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) is_admin = await self.auth.is_server_admin(requester.user) if not is_admin: @@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet): class PurgeMediaCacheRestServlet(RestServlet): PATTERNS = admin_patterns("/purge_media_cache") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.media_repository = hs.get_media_repository() self.auth = hs.get_auth() - async def on_POST(self, request): + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) @@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]+)/(?P[^/]+)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() - async def on_DELETE(self, request, server_name: str, media_id: str): + async def on_DELETE( + self, request: Request, server_name: str, media_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) if self.server_name != server_name: @@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]+)/delete") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() - async def on_POST(self, request, server_name: str): + async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) @@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet): return 200, {"deleted_media": deleted_media, "total": total} -def register_servlets_for_media_repo(hs, http_server): +def register_servlets_for_media_repo(hs: "HomeServer", http_server): """ Media repo specific APIs. """ @@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server): QuarantineMediaInRoom(hs).register(http_server) QuarantineMediaByID(hs).register(http_server) QuarantineMediaByUser(hs).register(http_server) + ProtectMediaByID(hs).register(http_server) ListMediaInRoom(hs).register(http_server) DeleteMediaByID(hs).register(http_server) DeleteMediaByDateSize(hs).register(http_server) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 586b877bda..9d22c04073 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -153,8 +153,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - # Allow for uploading and downloading to/from the media repo self.media_repo = hs.get_media_repository_resource() self.download_resource = self.media_repo.children[b"download"] @@ -428,7 +426,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Mark the second item as safe from quarantine. _, media_id_2 = server_and_media_id_2.split("/") - self.get_success(self.store.mark_local_media_as_safe(media_id_2)) + # Quarantine the media + url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),) + channel = self.make_request("POST", url, access_token=admin_user_tok) + self.pump(1.0) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) # Quarantine all media by this user url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( -- cgit 1.5.1 From b5dea8702d9b799a15d5b8fb90e82497b8c17f63 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 Jan 2021 18:03:33 +0000 Subject: Fix test failure due to bad merge 0dd2649c1 (#9112) changed the signature of `auth_via_oidc`. Meanwhile, 26d10331e (#9091) introduced a new test which relied on the old signature of `auth_via_oidc`. The two branches were never tested together until they landed in develop. --- tests/rest/client/v2_alpha/test_auth.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'tests/rest') diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 3e8661f9b9..a6488a3d29 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -475,7 +475,9 @@ class UIAuthTests(unittest.HomeserverTestCase): session_id = channel.json_body["session"] # do the OIDC auth, but auth as the wrong user - channel = self.helper.auth_via_oidc("wrong_user", ui_auth_session_id=session_id) + channel = self.helper.auth_via_oidc( + {"sub": "wrong_user"}, ui_auth_session_id=session_id + ) # that should return a failure message self.assertSubstring("We were unable to validate", channel.text_body) -- cgit 1.5.1 From 02070c69faa47bf6aef280939c2d5f32cbcb9f25 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 18 Jan 2021 14:52:49 +0000 Subject: Fix bugs in handling clientRedirectUrl, and improve OIDC tests (#9127, #9128) * Factor out a common TestHtmlParser Looks like I'm doing this in a few different places. * Improve OIDC login test Complete the OIDC login flow, rather than giving up halfway through. * Ensure that OIDC login works with multiple OIDC providers * Fix bugs in handling clientRedirectUrl - don't drop duplicate query-params, or params with no value - allow utf-8 in query-params --- changelog.d/9127.feature | 1 + changelog.d/9128.bugfix | 1 + synapse/handlers/auth.py | 4 +- synapse/handlers/oidc_handler.py | 2 +- synapse/rest/synapse/client/pick_idp.py | 4 +- tests/rest/client/v1/test_login.py | 146 ++++++++++++++++++++------------ tests/rest/client/v1/utils.py | 62 ++++++++------ tests/server.py | 2 +- tests/test_utils/html_parsers.py | 53 ++++++++++++ 9 files changed, 189 insertions(+), 86 deletions(-) create mode 100644 changelog.d/9127.feature create mode 100644 changelog.d/9128.bugfix create mode 100644 tests/test_utils/html_parsers.py (limited to 'tests/rest') diff --git a/changelog.d/9127.feature b/changelog.d/9127.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9127.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/changelog.d/9128.bugfix b/changelog.d/9128.bugfix new file mode 100644 index 0000000000..f87b9fb9aa --- /dev/null +++ b/changelog.d/9128.bugfix @@ -0,0 +1 @@ +Fix minor bugs in handling the `clientRedirectUrl` parameter for SSO login. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 18cd2b62f0..0e98db22b3 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1504,8 +1504,8 @@ class AuthHandler(BaseHandler): @staticmethod def add_query_param_to_url(url: str, param_name: str, param: Any): url_parts = list(urllib.parse.urlparse(url)) - query = dict(urllib.parse.parse_qsl(url_parts[4])) - query.update({param_name: param}) + query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True) + query.append((param_name, param)) url_parts[4] = urllib.parse.urlencode(query) return urllib.parse.urlunparse(url_parts) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 5e5fda7b2f..ba686d74b2 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -85,7 +85,7 @@ class OidcHandler: self._token_generator = OidcSessionTokenGenerator(hs) self._providers = { p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs - } + } # type: Dict[str, OidcProvider] async def load_metadata(self) -> None: """Validate the config and load the metadata from the remote endpoint. diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py index e5b720bbca..9550b82998 100644 --- a/synapse/rest/synapse/client/pick_idp.py +++ b/synapse/rest/synapse/client/pick_idp.py @@ -45,7 +45,9 @@ class PickIdpResource(DirectServeHtmlResource): self._server_name = hs.hostname async def _async_render_GET(self, request: SynapseRequest) -> None: - client_redirect_url = parse_string(request, "redirectUrl", required=True) + client_redirect_url = parse_string( + request, "redirectUrl", required=True, encoding="utf-8" + ) idp = parse_string(request, "idp", required=False) # if we need to pick an IdP, do so diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 73a009efd1..2d25490374 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -15,9 +15,8 @@ import time import urllib.parse -from html.parser import HTMLParser -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union -from urllib.parse import parse_qs, urlencode, urlparse +from typing import Any, Dict, Union +from urllib.parse import urlencode from mock import Mock @@ -38,6 +37,7 @@ from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless try: @@ -69,6 +69,12 @@ TEST_SAML_METADATA = """ LOGIN_URL = b"/_matrix/client/r0/login" TEST_URL = b"/_matrix/client/r0/account/whoami" +# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is + +TEST_CLIENT_REDIRECT_URL = 'https://x?&q"+%3D%2B"="fö%26=o"' + +# the query params in TEST_CLIENT_REDIRECT_URL +EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("", ""), ('q" =+"', '"fö&=o"')] + class LoginRestServletTestCase(unittest.HomeserverTestCase): @@ -389,23 +395,44 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): }, } + # default OIDC provider config["oidc_config"] = TEST_OIDC_CONFIG + # additional OIDC providers + config["oidc_providers"] = [ + { + "idp_id": "idp1", + "idp_name": "IDP1", + "discover": False, + "issuer": "https://issuer1", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["profile"], + "authorization_endpoint": "https://issuer1/auth", + "token_endpoint": "https://issuer1/token", + "userinfo_endpoint": "https://issuer1/userinfo", + "user_mapping_provider": { + "config": {"localpart_template": "{{ user.sub }}"} + }, + } + ] return config def create_resource_dict(self) -> Dict[str, Resource]: + from synapse.rest.oidc import OIDCResource + d = super().create_resource_dict() d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs) + d["/_synapse/oidc"] = OIDCResource(self.hs) return d def test_multi_sso_redirect(self): """/login/sso/redirect should redirect to an identity picker""" - client_redirect_url = "https://x?" - # first hit the redirect url, which should redirect to our idp picker channel = self.make_request( "GET", - "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url, + "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), ) self.assertEqual(channel.code, 302, channel.result) uri = channel.headers.getRawHeaders("Location")[0] @@ -415,46 +442,22 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) # parse the form to check it has fields assumed elsewhere in this class - class FormPageParser(HTMLParser): - def __init__(self): - super().__init__() - - # the values of the hidden inputs: map from name to value - self.hiddens = {} # type: Dict[str, Optional[str]] - - # the values of the radio buttons - self.radios = [] # type: List[Optional[str]] - - def handle_starttag( - self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] - ) -> None: - attr_dict = dict(attrs) - if tag == "input": - if attr_dict["type"] == "radio" and attr_dict["name"] == "idp": - self.radios.append(attr_dict["value"]) - elif attr_dict["type"] == "hidden": - input_name = attr_dict["name"] - assert input_name - self.hiddens[input_name] = attr_dict["value"] - - def error(_, message): - self.fail(message) - - p = FormPageParser() + p = TestHtmlParser() p.feed(channel.result["body"].decode("utf-8")) p.close() - self.assertCountEqual(p.radios, ["cas", "oidc", "saml"]) + self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "idp1", "saml"]) - self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url) + self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL) def test_multi_sso_redirect_to_cas(self): """If CAS is chosen, should redirect to the CAS server""" - client_redirect_url = "https://x?" channel = self.make_request( "GET", - "/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=cas", shorthand=False, ) self.assertEqual(channel.code, 302, channel.result) @@ -470,16 +473,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): service_uri = cas_uri_params["service"][0] _, service_uri_query = service_uri.split("?", 1) service_uri_params = urllib.parse.parse_qs(service_uri_query) - self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url) + self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) def test_multi_sso_redirect_to_saml(self): """If SAML is chosen, should redirect to the SAML server""" - client_redirect_url = "https://x?" - channel = self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=" - + client_redirect_url + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=saml", ) self.assertEqual(channel.code, 302, channel.result) @@ -492,16 +493,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # the RelayState is used to carry the client redirect url saml_uri_params = urllib.parse.parse_qs(saml_uri_query) relay_state_param = saml_uri_params["RelayState"][0] - self.assertEqual(relay_state_param, client_redirect_url) + self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) - def test_multi_sso_redirect_to_oidc(self): + def test_login_via_oidc(self): """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - client_redirect_url = "https://x?" + # pick the default OIDC provider channel = self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=" - + client_redirect_url + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=oidc", ) self.assertEqual(channel.code, 302, channel.result) @@ -521,9 +522,41 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) self.assertEqual( self._get_value_from_macaroon(macaroon, "client_redirect_url"), - client_redirect_url, + TEST_CLIENT_REDIRECT_URL, ) + channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) + + # that should serve a confirmation page + self.assertEqual(channel.code, 200, channel.result) + self.assertTrue( + channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html") + ) + p = TestHtmlParser() + p.feed(channel.text_body) + p.close() + + # ... which should contain our redirect link + self.assertEqual(len(p.links), 1) + path, query = p.links[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" + ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + login_token = params[2][1] + chan = self.make_request( + "POST", "/login", content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@user1:test") + def test_multi_sso_redirect_to_unknown(self): """An unknown IdP should cause a 400""" channel = self.make_request( @@ -1082,7 +1115,7 @@ class UsernamePickerTestCase(HomeserverTestCase): # whitelist this client URI so we redirect straight to it rather than # serving a confirmation page - config["sso"] = {"client_whitelist": ["https://whitelisted.client"]} + config["sso"] = {"client_whitelist": ["https://x"]} return config def create_resource_dict(self) -> Dict[str, Resource]: @@ -1095,11 +1128,10 @@ class UsernamePickerTestCase(HomeserverTestCase): def test_username_picker(self): """Test the happy path of a username picker flow.""" - client_redirect_url = "https://whitelisted.client" # do the start of the login flow channel = self.helper.auth_via_oidc( - {"sub": "tester", "displayname": "Jonny"}, client_redirect_url + {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL ) # that should redirect to the username picker @@ -1122,7 +1154,7 @@ class UsernamePickerTestCase(HomeserverTestCase): session = username_mapping_sessions[session_id] self.assertEqual(session.remote_user_id, "tester") self.assertEqual(session.display_name, "Jonny") - self.assertEqual(session.client_redirect_url, client_redirect_url) + self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL) # the expiry time should be about 15 minutes away expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) @@ -1146,15 +1178,19 @@ class UsernamePickerTestCase(HomeserverTestCase): ) self.assertEqual(chan.code, 302, chan.result) location_headers = chan.headers.getRawHeaders("Location") - # ensure that the returned location starts with the requested redirect URL - self.assertEqual( - location_headers[0][: len(client_redirect_url)], client_redirect_url + # ensure that the returned location matches the requested redirect URL + path, query = location_headers[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") # fish the login token out of the returned redirect uri - parts = urlparse(location_headers[0]) - query = parse_qs(parts.query) - login_token = query["loginToken"][0] + login_token = params[2][1] # finally, submit the matrix login token to the login API, which gives us our # matrix access token, mxid, and device id. diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index c6647dbe08..b1333df82d 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -20,8 +20,7 @@ import json import re import time import urllib.parse -from html.parser import HTMLParser -from typing import Any, Dict, Iterable, List, MutableMapping, Optional, Tuple +from typing import Any, Dict, Mapping, MutableMapping, Optional from mock import patch @@ -35,6 +34,7 @@ from synapse.types import JsonDict from tests.server import FakeChannel, FakeSite, make_request from tests.test_utils import FakeResponse +from tests.test_utils.html_parsers import TestHtmlParser @attr.s @@ -440,10 +440,36 @@ class RestHelper: # param that synapse passes to the IdP via query params, as well as the cookie # that synapse passes to the client. - oauth_uri_path, oauth_uri_qs = oauth_uri.split("?", 1) + oauth_uri_path, _ = oauth_uri.split("?", 1) assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( "unexpected SSO URI " + oauth_uri_path ) + return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) + + def complete_oidc_auth( + self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, + ) -> FakeChannel: + """Mock out an OIDC authentication flow + + Assumes that an OIDC auth has been initiated by one of initiate_sso_login or + initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to + Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get + sent back to the OIDC provider. + + Requires the OIDC callback resource to be mounted at the normal place. + + Args: + oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie, + from initiate_sso_login or initiate_sso_ui_auth). + cookies: the cookies set by synapse's redirect endpoint, which will be + sent back to the callback endpoint. + user_info_dict: the remote userinfo that the OIDC provider should present. + Typically this should be '{"sub": ""}'. + + Returns: + A FakeChannel containing the result of calling the OIDC callback endpoint. + """ + _, oauth_uri_qs = oauth_uri.split("?", 1) params = urllib.parse.parse_qs(oauth_uri_qs) callback_uri = "%s?%s" % ( urllib.parse.urlparse(params["redirect_uri"][0]).path, @@ -456,9 +482,9 @@ class RestHelper: expected_requests = [ # first we get a hit to the token endpoint, which we tell to return # a dummy OIDC access token - ("https://issuer.test/token", {"access_token": "TEST"}), + (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}), # and then one to the user_info endpoint, which returns our remote user id. - ("https://issuer.test/userinfo", user_info_dict), + (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), ] async def mock_req(method: str, uri: str, data=None, headers=None): @@ -542,25 +568,7 @@ class RestHelper: channel.extract_cookies(cookies) # parse the confirmation page to fish out the link. - class ConfirmationPageParser(HTMLParser): - def __init__(self): - super().__init__() - - self.links = [] # type: List[str] - - def handle_starttag( - self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] - ) -> None: - attr_dict = dict(attrs) - if tag == "a": - href = attr_dict["href"] - if href: - self.links.append(href) - - def error(_, message): - raise AssertionError(message) - - p = ConfirmationPageParser() + p = TestHtmlParser() p.feed(channel.text_body) p.close() assert len(p.links) == 1, "not exactly one link in confirmation page" @@ -570,6 +578,8 @@ class RestHelper: # an 'oidc_config' suitable for login_via_oidc. TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" +TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" +TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" TEST_OIDC_CONFIG = { "enabled": True, "discover": False, @@ -578,7 +588,7 @@ TEST_OIDC_CONFIG = { "client_secret": "test-client-secret", "scopes": ["profile"], "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, - "token_endpoint": "https://issuer.test/token", - "userinfo_endpoint": "https://issuer.test/userinfo", + "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, + "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, } diff --git a/tests/server.py b/tests/server.py index 5a1b66270f..5a85d5fe7f 100644 --- a/tests/server.py +++ b/tests/server.py @@ -74,7 +74,7 @@ class FakeChannel: return int(self.result["code"]) @property - def headers(self): + def headers(self) -> Headers: if not self.result: raise Exception("No result yet.") h = Headers() diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py new file mode 100644 index 0000000000..ad563eb3f0 --- /dev/null +++ b/tests/test_utils/html_parsers.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from html.parser import HTMLParser +from typing import Dict, Iterable, List, Optional, Tuple + + +class TestHtmlParser(HTMLParser): + """A generic HTML page parser which extracts useful things from the HTML""" + + def __init__(self): + super().__init__() + + # a list of links found in the doc + self.links = [] # type: List[str] + + # the values of any hidden s: map from name to value + self.hiddens = {} # type: Dict[str, Optional[str]] + + # the values of any radio buttons: map from name to list of values + self.radios = {} # type: Dict[str, List[Optional[str]]] + + def handle_starttag( + self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] + ) -> None: + attr_dict = dict(attrs) + if tag == "a": + href = attr_dict["href"] + if href: + self.links.append(href) + elif tag == "input": + input_name = attr_dict.get("name") + if attr_dict["type"] == "radio": + assert input_name + self.radios.setdefault(input_name, []).append(attr_dict["value"]) + elif attr_dict["type"] == "hidden": + assert input_name + self.hiddens[input_name] = attr_dict["value"] + + def error(_, message): + raise AssertionError(message) -- cgit 1.5.1 From fa50e4bf4ddcb8e98d44700513a28c490f80f02b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 20 Jan 2021 12:30:41 +0000 Subject: Give `public_baseurl` a default value (#9159) --- changelog.d/9159.feature | 1 + docs/sample_config.yaml | 31 +++++++++++++++++-------------- synapse/api/urls.py | 2 -- synapse/config/_base.py | 11 ++++++----- synapse/config/emailconfig.py | 8 -------- synapse/config/oidc_config.py | 2 -- synapse/config/registration.py | 21 ++++----------------- synapse/config/saml2_config.py | 2 -- synapse/config/server.py | 24 +++++++++++++++--------- synapse/config/sso.py | 13 +++++-------- synapse/handlers/identity.py | 2 -- synapse/rest/well_known.py | 4 ---- tests/rest/test_well_known.py | 9 --------- tests/utils.py | 1 - 14 files changed, 48 insertions(+), 83 deletions(-) create mode 100644 changelog.d/9159.feature (limited to 'tests/rest') diff --git a/changelog.d/9159.feature b/changelog.d/9159.feature new file mode 100644 index 0000000000..b7748757de --- /dev/null +++ b/changelog.d/9159.feature @@ -0,0 +1 @@ +Give the `public_baseurl` a default value, if it is not explicitly set in the configuration file. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index ae995efe9b..7fdd798d70 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -67,11 +67,16 @@ pid_file: DATADIR/homeserver.pid # #web_client_location: https://riot.example.com/ -# The public-facing base URL that clients use to access this HS -# (not including _matrix/...). This is the same URL a user would -# enter into the 'custom HS URL' field on their client. If you -# use synapse with a reverse proxy, this should be the URL to reach -# synapse via the proxy. +# The public-facing base URL that clients use to access this Homeserver (not +# including _matrix/...). This is the same URL a user might enter into the +# 'Custom Homeserver URL' field on their client. If you use Synapse with a +# reverse proxy, this should be the URL to reach Synapse via the proxy. +# Otherwise, it should be the URL to reach Synapse's client HTTP listener (see +# 'listeners' below). +# +# If this is left unset, it defaults to 'https:///'. (Note that +# that will not work unless you configure Synapse or a reverse-proxy to listen +# on port 443.) # #public_baseurl: https://example.com/ @@ -1150,8 +1155,9 @@ account_validity: # send an email to the account's email address with a renewal link. By # default, no such emails are sent. # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. + # If you enable this setting, you will also need to fill out the 'email' + # configuration section. You should also check that 'public_baseurl' is set + # correctly. # #renew_at: 1w @@ -1242,8 +1248,7 @@ account_validity: # The identity server which we suggest that clients should use when users log # in on this server. # -# (By default, no suggestion is made, so it is left up to the client. -# This setting is ignored unless public_baseurl is also set.) +# (By default, no suggestion is made, so it is left up to the client.) # #default_identity_server: https://matrix.org @@ -1268,8 +1273,6 @@ account_validity: # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # -# If a delegate is specified, the config option public_baseurl must also be filled out. -# account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process @@ -1901,9 +1904,9 @@ sso: # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # If public_baseurl is set, then the login fallback page (used by clients - # that don't natively support the required login flows) is whitelisted in - # addition to any URLs in this list. + # The login fallback page (used by clients that don't natively support the + # required login flows) is automatically whitelisted in addition to any URLs + # in this list. # # By default, this list is empty. # diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 6379c86dde..e36aeef31f 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -42,8 +42,6 @@ class ConsentURIBuilder: """ if hs_config.form_secret is None: raise ConfigError("form_secret not set in config") - if hs_config.public_baseurl is None: - raise ConfigError("public_baseurl not set in config") self._hmac_secret = hs_config.form_secret.encode("utf-8") self._public_baseurl = hs_config.public_baseurl diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 2931a88207..94144efc87 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -252,11 +252,12 @@ class Config: env = jinja2.Environment(loader=loader, autoescape=autoescape) # Update the environment with our custom filters - env.filters.update({"format_ts": _format_ts_filter}) - if self.public_baseurl: - env.filters.update( - {"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl)} - ) + env.filters.update( + { + "format_ts": _format_ts_filter, + "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl), + } + ) for filename in filenames: # Load the template diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index d4328c46b9..6a487afd34 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -166,11 +166,6 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") - # public_baseurl is required to build password reset and validation links that - # will be emailed to users - if config.get("public_baseurl") is None: - missing.append("public_baseurl") - if missing: raise ConfigError( MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),) @@ -269,9 +264,6 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") - if config.get("public_baseurl") is None: - missing.append("public_baseurl") - if missing: raise ConfigError( "email.enable_notifs is True but required keys are missing: %s" diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 80a24cfbc9..df55367434 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -43,8 +43,6 @@ class OIDCConfig(Config): raise ConfigError(e.message) from e public_baseurl = self.public_baseurl - if public_baseurl is None: - raise ConfigError("oidc_config requires a public_baseurl to be set") self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" @property diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 740c3fc1b1..4bfc69cb7a 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -49,10 +49,6 @@ class AccountValidityConfig(Config): self.startup_job_max_delta = self.period * 10.0 / 100.0 - if self.renew_by_email_enabled: - if "public_baseurl" not in synapse_config: - raise ConfigError("Can't send renewal emails without 'public_baseurl'") - template_dir = config.get("template_dir") if not template_dir: @@ -109,13 +105,6 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") - if self.account_threepid_delegate_msisdn and not self.public_baseurl: - raise ConfigError( - "The configuration option `public_baseurl` is required if " - "`account_threepid_delegate.msisdn` is set, such that " - "clients know where to submit validation tokens to. Please " - "configure `public_baseurl`." - ) self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) @@ -240,8 +229,9 @@ class RegistrationConfig(Config): # send an email to the account's email address with a renewal link. By # default, no such emails are sent. # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. + # If you enable this setting, you will also need to fill out the 'email' + # configuration section. You should also check that 'public_baseurl' is set + # correctly. # #renew_at: 1w @@ -332,8 +322,7 @@ class RegistrationConfig(Config): # The identity server which we suggest that clients should use when users log # in on this server. # - # (By default, no suggestion is made, so it is left up to the client. - # This setting is ignored unless public_baseurl is also set.) + # (By default, no suggestion is made, so it is left up to the client.) # #default_identity_server: https://matrix.org @@ -358,8 +347,6 @@ class RegistrationConfig(Config): # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # - # If a delegate is specified, the config option public_baseurl must also be filled out. - # account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 7b97d4f114..f33dfa0d6a 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -189,8 +189,6 @@ class SAML2Config(Config): import saml2 public_baseurl = self.public_baseurl - if public_baseurl is None: - raise ConfigError("saml2_config requires a public_baseurl to be set") if self.saml2_grandfathered_mxid_source_attribute: optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) diff --git a/synapse/config/server.py b/synapse/config/server.py index 7242a4aa8e..75ba161f35 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -161,7 +161,11 @@ class ServerConfig(Config): self.print_pidfile = config.get("print_pidfile") self.user_agent_suffix = config.get("user_agent_suffix") self.use_frozen_dicts = config.get("use_frozen_dicts", False) - self.public_baseurl = config.get("public_baseurl") + self.public_baseurl = config.get("public_baseurl") or "https://%s/" % ( + self.server_name, + ) + if self.public_baseurl[-1] != "/": + self.public_baseurl += "/" # Whether to enable user presence. self.use_presence = config.get("use_presence", True) @@ -317,9 +321,6 @@ class ServerConfig(Config): # Always blacklist 0.0.0.0, :: self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) - if self.public_baseurl is not None: - if self.public_baseurl[-1] != "/": - self.public_baseurl += "/" self.start_pushers = config.get("start_pushers", True) # (undocumented) option for torturing the worker-mode replication a bit, @@ -740,11 +741,16 @@ class ServerConfig(Config): # #web_client_location: https://riot.example.com/ - # The public-facing base URL that clients use to access this HS - # (not including _matrix/...). This is the same URL a user would - # enter into the 'custom HS URL' field on their client. If you - # use synapse with a reverse proxy, this should be the URL to reach - # synapse via the proxy. + # The public-facing base URL that clients use to access this Homeserver (not + # including _matrix/...). This is the same URL a user might enter into the + # 'Custom Homeserver URL' field on their client. If you use Synapse with a + # reverse proxy, this should be the URL to reach Synapse via the proxy. + # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see + # 'listeners' below). + # + # If this is left unset, it defaults to 'https:///'. (Note that + # that will not work unless you configure Synapse or a reverse-proxy to listen + # on port 443.) # #public_baseurl: https://example.com/ diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 366f0d4698..59be825532 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -64,11 +64,8 @@ class SSOConfig(Config): # gracefully to the client). This would make it pointless to ask the user for # confirmation, since the URL the confirmation page would be showing wouldn't be # the client's. - # public_baseurl is an optional setting, so we only add the fallback's URL to the - # list if it's provided (because we can't figure out what that URL is otherwise). - if self.public_baseurl: - login_fallback_url = self.public_baseurl + "_matrix/static/client/login" - self.sso_client_whitelist.append(login_fallback_url) + login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + self.sso_client_whitelist.append(login_fallback_url) def generate_config_section(self, **kwargs): return """\ @@ -86,9 +83,9 @@ class SSOConfig(Config): # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # If public_baseurl is set, then the login fallback page (used by clients - # that don't natively support the required login flows) is whitelisted in - # addition to any URLs in this list. + # The login fallback page (used by clients that don't natively support the + # required login flows) is automatically whitelisted in addition to any URLs + # in this list. # # By default, this list is empty. # diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index c05036ad1f..f61844d688 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -476,8 +476,6 @@ class IdentityHandler(BaseHandler): except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") - assert self.hs.config.public_baseurl - # we need to tell the client to send the token back to us, since it doesn't # otherwise know where to send it, so add submit_url response parameter # (see also MSC2078) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index f591cc6c5c..241fe746d9 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -34,10 +34,6 @@ class WellKnownBuilder: self._config = hs.config def get_well_known(self): - # if we don't have a public_baseurl, we can't help much here. - if self._config.public_baseurl is None: - return None - result = {"m.homeserver": {"base_url": self._config.public_baseurl}} if self._config.default_identity_server: diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 14de0921be..c5e44af9f7 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -40,12 +40,3 @@ class WellKnownTests(unittest.HomeserverTestCase): "m.identity_server": {"base_url": "https://testis"}, }, ) - - def test_well_known_no_public_baseurl(self): - self.hs.config.public_baseurl = None - - channel = self.make_request( - "GET", "/.well-known/matrix/client", shorthand=False - ) - - self.assertEqual(channel.code, 404) diff --git a/tests/utils.py b/tests/utils.py index 977eeaf6ee..09614093bc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -159,7 +159,6 @@ def default_config(name, parse=False): "remote": {"per_second": 10000, "burst_count": 10000}, }, "saml2_enabled": False, - "public_baseurl": None, "default_identity_server": None, "key_refresh_interval": 24 * 60 * 60 * 1000, "old_signing_keys": {}, -- cgit 1.5.1 From 7447f197026db570c1c1af240642566b31f81e42 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 21 Jan 2021 12:25:02 +0000 Subject: Prefix idp_id with "oidc-" (#9189) ... to avoid clashes with other SSO mechanisms --- changelog.d/9189.misc | 1 + docs/sample_config.yaml | 13 +++++++++---- synapse/config/oidc_config.py | 28 ++++++++++++++++++++++++---- tests/rest/client/v1/test_login.py | 2 +- 4 files changed, 35 insertions(+), 9 deletions(-) create mode 100644 changelog.d/9189.misc (limited to 'tests/rest') diff --git a/changelog.d/9189.misc b/changelog.d/9189.misc new file mode 100644 index 0000000000..9a5740aac2 --- /dev/null +++ b/changelog.d/9189.misc @@ -0,0 +1 @@ +Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index b49a5da8cc..87bfe22237 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1728,7 +1728,9 @@ saml2_config: # # idp_icon: An optional icon for this identity provider, which is presented # by identity picker pages. If given, must be an MXC URI of the format -# mxc:/// +# mxc:///. (An easy way to obtain such an MXC URI +# is to upload an image to an (unencrypted) room and then copy the "url" +# from the source of the event.) # # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. @@ -1814,13 +1816,16 @@ saml2_config: # # For backwards compatibility, it is also possible to configure a single OIDC # provider via an 'oidc_config' setting. This is now deprecated and admins are -# advised to migrate to the 'oidc_providers' format. +# advised to migrate to the 'oidc_providers' format. (When doing that migration, +# use 'oidc' for the idp_id to ensure that existing users continue to be +# recognised.) # oidc_providers: # Generic example # #- idp_id: my_idp # idp_name: "My OpenID provider" + # idp_icon: "mxc://example.com/mediaid" # discover: false # issuer: "https://accounts.example.com/" # client_id: "provided-by-your-issuer" @@ -1844,8 +1849,8 @@ oidc_providers: # For use with Github # - #- idp_id: google - # idp_name: Google + #- idp_id: github + # idp_name: Github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 8cb0c42f36..d58a83be7f 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -69,7 +69,9 @@ class OIDCConfig(Config): # # idp_icon: An optional icon for this identity provider, which is presented # by identity picker pages. If given, must be an MXC URI of the format - # mxc:/// + # mxc:///. (An easy way to obtain such an MXC URI + # is to upload an image to an (unencrypted) room and then copy the "url" + # from the source of the event.) # # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. @@ -155,13 +157,16 @@ class OIDCConfig(Config): # # For backwards compatibility, it is also possible to configure a single OIDC # provider via an 'oidc_config' setting. This is now deprecated and admins are - # advised to migrate to the 'oidc_providers' format. + # advised to migrate to the 'oidc_providers' format. (When doing that migration, + # use 'oidc' for the idp_id to ensure that existing users continue to be + # recognised.) # oidc_providers: # Generic example # #- idp_id: my_idp # idp_name: "My OpenID provider" + # idp_icon: "mxc://example.com/mediaid" # discover: false # issuer: "https://accounts.example.com/" # client_id: "provided-by-your-issuer" @@ -185,8 +190,8 @@ class OIDCConfig(Config): # For use with Github # - #- idp_id: google - # idp_name: Google + #- idp_id: github + # idp_name: Github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED @@ -210,6 +215,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "type": "object", "required": ["issuer", "client_id", "client_secret"], "properties": { + # TODO: fix the maxLength here depending on what MSC2528 decides + # remember that we prefix the ID given here with `oidc-` "idp_id": {"type": "string", "minLength": 1, "maxLength": 128}, "idp_name": {"type": "string"}, "idp_icon": {"type": "string"}, @@ -335,6 +342,8 @@ def _parse_oidc_config_dict( # enforce those limits now. # TODO: factor out this stuff to a generic function idp_id = oidc_config.get("idp_id", "oidc") + + # TODO: update this validity check based on what MSC2858 decides. valid_idp_chars = set(string.ascii_lowercase + string.digits + "-._") if any(c not in valid_idp_chars for c in idp_id): @@ -348,6 +357,17 @@ def _parse_oidc_config_dict( "idp_id must start with a-z", config_path + ("idp_id",), ) + # prefix the given IDP with a prefix specific to the SSO mechanism, to avoid + # clashes with other mechs (such as SAML, CAS). + # + # We allow "oidc" as an exception so that people migrating from old-style + # "oidc_config" format (which has long used "oidc" as its idp_id) can migrate to + # a new-style "oidc_providers" entry without changing the idp_id for their provider + # (and thereby invalidating their user_external_ids data). + + if idp_id != "oidc": + idp_id = "oidc-" + idp_id + # MSC2858 also specifies that the idp_icon must be a valid MXC uri idp_icon = oidc_config.get("idp_icon") if idp_icon is not None: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 2d25490374..2672ce24c6 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -446,7 +446,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): p.feed(channel.result["body"].decode("utf-8")) p.close() - self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "idp1", "saml"]) + self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"]) self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL) -- cgit 1.5.1 From c55e62548c0fddd49e7182133880d2ccb03dbb42 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 21 Jan 2021 15:18:46 +0100 Subject: Add tests for List Users Admin API (#9045) --- changelog.d/9045.misc | 1 + synapse/rest/admin/users.py | 21 +++- tests/rest/admin/test_user.py | 223 +++++++++++++++++++++++++++++++++++++----- 3 files changed, 215 insertions(+), 30 deletions(-) create mode 100644 changelog.d/9045.misc (limited to 'tests/rest') diff --git a/changelog.d/9045.misc b/changelog.d/9045.misc new file mode 100644 index 0000000000..7f1886a0de --- /dev/null +++ b/changelog.d/9045.misc @@ -0,0 +1 @@ +Add tests to `test_user.UsersListTestCase` for List Users Admin API. \ No newline at end of file diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index f39e3d6d5c..86198bab30 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -83,17 +83,32 @@ class UsersRestServletV2(RestServlet): The parameter `deactivated` can be used to include deactivated users. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + 400, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + user_id = parse_string(request, "user_id", default=None) name = parse_string(request, "name", default=None) guests = parse_boolean(request, "guests", default=True) @@ -103,7 +118,7 @@ class UsersRestServletV2(RestServlet): start, limit, user_id, name, guests, deactivated ) ret = {"users": users, "total": total} - if len(users) >= limit: + if (start + limit) < total: ret["next_token"] = str(start + len(users)) return 200, ret diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 04599c2fcf..e48f8c1d7b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -28,6 +28,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions from synapse.rest.client.v1 import login, logout, profile, room from synapse.rest.client.v2_alpha import devices, sync +from synapse.types import JsonDict from tests import unittest from tests.test_utils import make_awaitable @@ -468,13 +469,6 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - self.user1 = self.register_user( - "user1", "pass1", admin=False, displayname="Name 1" - ) - self.user2 = self.register_user( - "user2", "pass2", admin=False, displayname="Name 2" - ) - def test_no_auth(self): """ Try to list users without authentication. @@ -488,6 +482,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ If the user is not a server admin, an error is returned. """ + self._create_users(1) other_user_token = self.login("user1", "pass1") channel = self.make_request("GET", self.url, access_token=other_user_token) @@ -499,6 +494,8 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ List all users, including deactivated users. """ + self._create_users(2) + channel = self.make_request( "GET", self.url + "?deactivated=true", @@ -511,14 +508,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(3, channel.json_body["total"]) # Check that all fields are available - for u in channel.json_body["users"]: - self.assertIn("name", u) - self.assertIn("is_guest", u) - self.assertIn("admin", u) - self.assertIn("user_type", u) - self.assertIn("deactivated", u) - self.assertIn("displayname", u) - self.assertIn("avatar_url", u) + self._check_fields(channel.json_body["users"]) def test_search_term(self): """Test that searching for a users works correctly""" @@ -549,6 +539,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): # Check that users were returned self.assertTrue("users" in channel.json_body) + self._check_fields(channel.json_body["users"]) users = channel.json_body["users"] # Check that the expected number of users were returned @@ -561,25 +552,30 @@ class UsersListTestCase(unittest.HomeserverTestCase): u = users[0] self.assertEqual(expected_user_id, u["name"]) + self._create_users(2) + + user1 = "@user1:test" + user2 = "@user2:test" + # Perform search tests - _search_test(self.user1, "er1") - _search_test(self.user1, "me 1") + _search_test(user1, "er1") + _search_test(user1, "me 1") - _search_test(self.user2, "er2") - _search_test(self.user2, "me 2") + _search_test(user2, "er2") + _search_test(user2, "me 2") - _search_test(self.user1, "er1", "user_id") - _search_test(self.user2, "er2", "user_id") + _search_test(user1, "er1", "user_id") + _search_test(user2, "er2", "user_id") # Test case insensitive - _search_test(self.user1, "ER1") - _search_test(self.user1, "NAME 1") + _search_test(user1, "ER1") + _search_test(user1, "NAME 1") - _search_test(self.user2, "ER2") - _search_test(self.user2, "NAME 2") + _search_test(user2, "ER2") + _search_test(user2, "NAME 2") - _search_test(self.user1, "ER1", "user_id") - _search_test(self.user2, "ER2", "user_id") + _search_test(user1, "ER1", "user_id") + _search_test(user2, "ER2", "user_id") _search_test(None, "foo") _search_test(None, "bar") @@ -587,6 +583,179 @@ class UsersListTestCase(unittest.HomeserverTestCase): _search_test(None, "foo", "user_id") _search_test(None, "bar", "user_id") + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. + """ + + # negative limit + channel = self.make_request( + "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # negative from + channel = self.make_request( + "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # invalid guests + channel = self.make_request( + "GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # invalid deactivated + channel = self.make_request( + "GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + def test_limit(self): + """ + Testing list of users with limit + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 5) + self.assertEqual(channel.json_body["next_token"], "5") + self._check_fields(channel.json_body["users"]) + + def test_from(self): + """ + Testing list of users with a defined starting point (from) + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", self.url + "?from=5", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 15) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["users"]) + + def test_limit_and_from(self): + """ + Testing list of users with a defined starting point and limit + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(channel.json_body["next_token"], "15") + self.assertEqual(len(channel.json_body["users"]), 10) + self._check_fields(channel.json_body["users"]) + + def test_next_token(self): + """ + Testing that `next_token` appears at the right place + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + # `next_token` does not appear + # Number of results is the number of entries + channel = self.make_request( + "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), number_users) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + channel = self.make_request( + "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), number_users) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + channel = self.make_request( + "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 19) + self.assertEqual(channel.json_body["next_token"], "19") + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + channel = self.make_request( + "GET", self.url + "?from=19", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def _check_fields(self, content: JsonDict): + """Checks that the expected user attributes are present in content + Args: + content: List that is checked for content + """ + for u in content: + self.assertIn("name", u) + self.assertIn("is_guest", u) + self.assertIn("admin", u) + self.assertIn("user_type", u) + self.assertIn("deactivated", u) + self.assertIn("displayname", u) + self.assertIn("avatar_url", u) + + def _create_users(self, number_users: int): + """ + Create a number of users + Args: + number_users: Number of users to be created + """ + for i in range(1, number_users + 1): + self.register_user( + "user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i, + ) + class DeactivateAccountTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From a7882f98874684969910d3a6ed7d85f99114cc45 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 21 Jan 2021 14:53:58 -0500 Subject: Return a 404 if no valid thumbnail is found. (#9163) If no thumbnail of the requested type exists, return a 404 instead of erroring. This doesn't quite match the spec (which does not define what happens if no thumbnail can be found), but is consistent with what Synapse already does. --- changelog.d/9163.bugfix | 1 + synapse/rest/media/v1/_base.py | 3 + synapse/rest/media/v1/thumbnail_resource.py | 236 ++++++++++++++++++---------- tests/rest/media/v1/test_media_storage.py | 25 ++- 4 files changed, 183 insertions(+), 82 deletions(-) create mode 100644 changelog.d/9163.bugfix (limited to 'tests/rest') diff --git a/changelog.d/9163.bugfix b/changelog.d/9163.bugfix new file mode 100644 index 0000000000..c51cf6ca80 --- /dev/null +++ b/changelog.d/9163.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would return a 500 error when a thumbnail did not exist (and auto-generation of thumbnails was not enabled). diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 31a41e4a27..f71a03a12d 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -300,6 +300,7 @@ class FileInfo: thumbnail_height (int) thumbnail_method (str) thumbnail_type (str): Content type of thumbnail, e.g. image/png + thumbnail_length (int): The size of the media file, in bytes. """ def __init__( @@ -312,6 +313,7 @@ class FileInfo: thumbnail_height=None, thumbnail_method=None, thumbnail_type=None, + thumbnail_length=None, ): self.server_name = server_name self.file_id = file_id @@ -321,6 +323,7 @@ class FileInfo: self.thumbnail_height = thumbnail_height self.thumbnail_method = thumbnail_method self.thumbnail_type = thumbnail_type + self.thumbnail_length = thumbnail_length def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index d6880f2e6e..d653a58be9 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -16,7 +16,7 @@ import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional from twisted.web.http import Request @@ -106,31 +106,17 @@ class ThumbnailResource(DirectServeJsonResource): return thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) - - if thumbnail_infos: - thumbnail_info = self._select_thumbnail( - width, height, method, m_type, thumbnail_infos - ) - - file_info = FileInfo( - server_name=None, - file_id=media_id, - url_cache=media_info["url_cache"], - thumbnail=True, - thumbnail_width=thumbnail_info["thumbnail_width"], - thumbnail_height=thumbnail_info["thumbnail_height"], - thumbnail_type=thumbnail_info["thumbnail_type"], - thumbnail_method=thumbnail_info["thumbnail_method"], - ) - - t_type = file_info.thumbnail_type - t_length = thumbnail_info["thumbnail_length"] - - responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder(request, responder, t_type, t_length) - else: - logger.info("Couldn't find any generated thumbnails") - respond_404(request) + await self._select_and_respond_with_thumbnail( + request, + width, + height, + method, + m_type, + thumbnail_infos, + media_id, + url_cache=media_info["url_cache"], + server_name=None, + ) async def _select_or_generate_local_thumbnail( self, @@ -276,26 +262,64 @@ class ThumbnailResource(DirectServeJsonResource): thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) + await self._select_and_respond_with_thumbnail( + request, + width, + height, + method, + m_type, + thumbnail_infos, + media_info["filesystem_id"], + url_cache=None, + server_name=server_name, + ) + async def _select_and_respond_with_thumbnail( + self, + request: Request, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + thumbnail_infos: List[Dict[str, Any]], + file_id: str, + url_cache: Optional[str] = None, + server_name: Optional[str] = None, + ) -> None: + """ + Respond to a request with an appropriate thumbnail from the previously generated thumbnails. + + Args: + request: The incoming request. + desired_width: The desired width, the returned thumbnail may be larger than this. + desired_height: The desired height, the returned thumbnail may be larger than this. + desired_method: The desired method used to generate the thumbnail. + desired_type: The desired content-type of the thumbnail. + thumbnail_infos: A list of dictionaries of candidate thumbnails. + file_id: The ID of the media that a thumbnail is being requested for. + url_cache: The URL cache value. + server_name: The server name, if this is a remote thumbnail. + """ if thumbnail_infos: - thumbnail_info = self._select_thumbnail( - width, height, method, m_type, thumbnail_infos + file_info = self._select_thumbnail( + desired_width, + desired_height, + desired_method, + desired_type, + thumbnail_infos, + file_id, + url_cache, + server_name, ) - file_info = FileInfo( - server_name=server_name, - file_id=media_info["filesystem_id"], - thumbnail=True, - thumbnail_width=thumbnail_info["thumbnail_width"], - thumbnail_height=thumbnail_info["thumbnail_height"], - thumbnail_type=thumbnail_info["thumbnail_type"], - thumbnail_method=thumbnail_info["thumbnail_method"], - ) - - t_type = file_info.thumbnail_type - t_length = thumbnail_info["thumbnail_length"] + if not file_info: + logger.info("Couldn't find a thumbnail matching the desired inputs") + respond_404(request) + return responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder( + request, responder, file_info.thumbnail_type, file_info.thumbnail_length + ) else: logger.info("Failed to find any generated thumbnails") respond_404(request) @@ -306,67 +330,117 @@ class ThumbnailResource(DirectServeJsonResource): desired_height: int, desired_method: str, desired_type: str, - thumbnail_infos, - ) -> dict: + thumbnail_infos: List[Dict[str, Any]], + file_id: str, + url_cache: Optional[str], + server_name: Optional[str], + ) -> Optional[FileInfo]: + """ + Choose an appropriate thumbnail from the previously generated thumbnails. + + Args: + desired_width: The desired width, the returned thumbnail may be larger than this. + desired_height: The desired height, the returned thumbnail may be larger than this. + desired_method: The desired method used to generate the thumbnail. + desired_type: The desired content-type of the thumbnail. + thumbnail_infos: A list of dictionaries of candidate thumbnails. + file_id: The ID of the media that a thumbnail is being requested for. + url_cache: The URL cache value. + server_name: The server name, if this is a remote thumbnail. + + Returns: + The thumbnail which best matches the desired parameters. + """ + desired_method = desired_method.lower() + + # The chosen thumbnail. + thumbnail_info = None + d_w = desired_width d_h = desired_height - if desired_method.lower() == "crop": + if desired_method == "crop": + # Thumbnails that match equal or larger sizes of desired width/height. crop_info_list = [] + # Other thumbnails. crop_info_list2 = [] for info in thumbnail_infos: + # Skip thumbnails generated with different methods. + if info["thumbnail_method"] != "crop": + continue + t_w = info["thumbnail_width"] t_h = info["thumbnail_height"] - t_method = info["thumbnail_method"] - if t_method == "crop": - aspect_quality = abs(d_w * t_h - d_h * t_w) - min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 - size_quality = abs((d_w - t_w) * (d_h - t_h)) - type_quality = desired_type != info["thumbnail_type"] - length_quality = info["thumbnail_length"] - if t_w >= d_w or t_h >= d_h: - crop_info_list.append( - ( - aspect_quality, - min_quality, - size_quality, - type_quality, - length_quality, - info, - ) + aspect_quality = abs(d_w * t_h - d_h * t_w) + min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 + size_quality = abs((d_w - t_w) * (d_h - t_h)) + type_quality = desired_type != info["thumbnail_type"] + length_quality = info["thumbnail_length"] + if t_w >= d_w or t_h >= d_h: + crop_info_list.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, ) - else: - crop_info_list2.append( - ( - aspect_quality, - min_quality, - size_quality, - type_quality, - length_quality, - info, - ) + ) + else: + crop_info_list2.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, ) + ) if crop_info_list: - return min(crop_info_list)[-1] - else: - return min(crop_info_list2)[-1] - else: + thumbnail_info = min(crop_info_list)[-1] + elif crop_info_list2: + thumbnail_info = min(crop_info_list2)[-1] + elif desired_method == "scale": + # Thumbnails that match equal or larger sizes of desired width/height. info_list = [] + # Other thumbnails. info_list2 = [] + for info in thumbnail_infos: + # Skip thumbnails generated with different methods. + if info["thumbnail_method"] != "scale": + continue + t_w = info["thumbnail_width"] t_h = info["thumbnail_height"] - t_method = info["thumbnail_method"] size_quality = abs((d_w - t_w) * (d_h - t_h)) type_quality = desired_type != info["thumbnail_type"] length_quality = info["thumbnail_length"] - if t_method == "scale" and (t_w >= d_w or t_h >= d_h): + if t_w >= d_w or t_h >= d_h: info_list.append((size_quality, type_quality, length_quality, info)) - elif t_method == "scale": + else: info_list2.append( (size_quality, type_quality, length_quality, info) ) if info_list: - return min(info_list)[-1] - else: - return min(info_list2)[-1] + thumbnail_info = min(info_list)[-1] + elif info_list2: + thumbnail_info = min(info_list2)[-1] + + if thumbnail_info: + return FileInfo( + file_id=file_id, + url_cache=url_cache, + server_name=server_name, + thumbnail=True, + thumbnail_width=thumbnail_info["thumbnail_width"], + thumbnail_height=thumbnail_info["thumbnail_height"], + thumbnail_type=thumbnail_info["thumbnail_type"], + thumbnail_method=thumbnail_info["thumbnail_method"], + thumbnail_length=thumbnail_info["thumbnail_length"], + ) + + # No matching thumbnail was found. + return None diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index ae2b32b131..a6c6985173 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -202,7 +202,6 @@ class MediaRepoTests(unittest.HomeserverTestCase): config = self.default_config() config["media_store_path"] = self.media_store_path - config["thumbnail_requirements"] = {} config["max_image_pixels"] = 2000000 provider_config = { @@ -313,15 +312,39 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) def test_thumbnail_crop(self): + """Test that a cropped remote thumbnail is available.""" self._test_thumbnail( "crop", self.test_image.expected_cropped, self.test_image.expected_found ) def test_thumbnail_scale(self): + """Test that a scaled remote thumbnail is available.""" self._test_thumbnail( "scale", self.test_image.expected_scaled, self.test_image.expected_found ) + def test_invalid_type(self): + """An invalid thumbnail type is never available.""" + self._test_thumbnail("invalid", None, False) + + @unittest.override_config( + {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]} + ) + def test_no_thumbnail_crop(self): + """ + Override the config to generate only scaled thumbnails, but request a cropped one. + """ + self._test_thumbnail("crop", None, False) + + @unittest.override_config( + {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]} + ) + def test_no_thumbnail_scale(self): + """ + Override the config to generate only cropped thumbnails, but request a scaled one. + """ + self._test_thumbnail("scale", None, False) + def _test_thumbnail(self, method, expected_body, expected_found): params = "?width=32&height=32&method=" + method channel = make_request( -- cgit 1.5.1