From fa50e4bf4ddcb8e98d44700513a28c490f80f02b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 20 Jan 2021 12:30:41 +0000 Subject: Give `public_baseurl` a default value (#9159) --- tests/rest/test_well_known.py | 9 --------- 1 file changed, 9 deletions(-) (limited to 'tests/rest') diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 14de0921be..c5e44af9f7 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -40,12 +40,3 @@ class WellKnownTests(unittest.HomeserverTestCase): "m.identity_server": {"base_url": "https://testis"}, }, ) - - def test_well_known_no_public_baseurl(self): - self.hs.config.public_baseurl = None - - channel = self.make_request( - "GET", "/.well-known/matrix/client", shorthand=False - ) - - self.assertEqual(channel.code, 404) -- cgit 1.5.1 From 7447f197026db570c1c1af240642566b31f81e42 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 21 Jan 2021 12:25:02 +0000 Subject: Prefix idp_id with "oidc-" (#9189) ... to avoid clashes with other SSO mechanisms --- changelog.d/9189.misc | 1 + docs/sample_config.yaml | 13 +++++++++---- synapse/config/oidc_config.py | 28 ++++++++++++++++++++++++---- tests/rest/client/v1/test_login.py | 2 +- 4 files changed, 35 insertions(+), 9 deletions(-) create mode 100644 changelog.d/9189.misc (limited to 'tests/rest') diff --git a/changelog.d/9189.misc b/changelog.d/9189.misc new file mode 100644 index 0000000000..9a5740aac2 --- /dev/null +++ b/changelog.d/9189.misc @@ -0,0 +1 @@ +Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index b49a5da8cc..87bfe22237 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1728,7 +1728,9 @@ saml2_config: # # idp_icon: An optional icon for this identity provider, which is presented # by identity picker pages. If given, must be an MXC URI of the format -# mxc:/// +# mxc:///. (An easy way to obtain such an MXC URI +# is to upload an image to an (unencrypted) room and then copy the "url" +# from the source of the event.) # # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. @@ -1814,13 +1816,16 @@ saml2_config: # # For backwards compatibility, it is also possible to configure a single OIDC # provider via an 'oidc_config' setting. This is now deprecated and admins are -# advised to migrate to the 'oidc_providers' format. +# advised to migrate to the 'oidc_providers' format. (When doing that migration, +# use 'oidc' for the idp_id to ensure that existing users continue to be +# recognised.) # oidc_providers: # Generic example # #- idp_id: my_idp # idp_name: "My OpenID provider" + # idp_icon: "mxc://example.com/mediaid" # discover: false # issuer: "https://accounts.example.com/" # client_id: "provided-by-your-issuer" @@ -1844,8 +1849,8 @@ oidc_providers: # For use with Github # - #- idp_id: google - # idp_name: Google + #- idp_id: github + # idp_name: Github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 8cb0c42f36..d58a83be7f 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -69,7 +69,9 @@ class OIDCConfig(Config): # # idp_icon: An optional icon for this identity provider, which is presented # by identity picker pages. If given, must be an MXC URI of the format - # mxc:/// + # mxc:///. (An easy way to obtain such an MXC URI + # is to upload an image to an (unencrypted) room and then copy the "url" + # from the source of the event.) # # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. @@ -155,13 +157,16 @@ class OIDCConfig(Config): # # For backwards compatibility, it is also possible to configure a single OIDC # provider via an 'oidc_config' setting. This is now deprecated and admins are - # advised to migrate to the 'oidc_providers' format. + # advised to migrate to the 'oidc_providers' format. (When doing that migration, + # use 'oidc' for the idp_id to ensure that existing users continue to be + # recognised.) # oidc_providers: # Generic example # #- idp_id: my_idp # idp_name: "My OpenID provider" + # idp_icon: "mxc://example.com/mediaid" # discover: false # issuer: "https://accounts.example.com/" # client_id: "provided-by-your-issuer" @@ -185,8 +190,8 @@ class OIDCConfig(Config): # For use with Github # - #- idp_id: google - # idp_name: Google + #- idp_id: github + # idp_name: Github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED @@ -210,6 +215,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "type": "object", "required": ["issuer", "client_id", "client_secret"], "properties": { + # TODO: fix the maxLength here depending on what MSC2528 decides + # remember that we prefix the ID given here with `oidc-` "idp_id": {"type": "string", "minLength": 1, "maxLength": 128}, "idp_name": {"type": "string"}, "idp_icon": {"type": "string"}, @@ -335,6 +342,8 @@ def _parse_oidc_config_dict( # enforce those limits now. # TODO: factor out this stuff to a generic function idp_id = oidc_config.get("idp_id", "oidc") + + # TODO: update this validity check based on what MSC2858 decides. valid_idp_chars = set(string.ascii_lowercase + string.digits + "-._") if any(c not in valid_idp_chars for c in idp_id): @@ -348,6 +357,17 @@ def _parse_oidc_config_dict( "idp_id must start with a-z", config_path + ("idp_id",), ) + # prefix the given IDP with a prefix specific to the SSO mechanism, to avoid + # clashes with other mechs (such as SAML, CAS). + # + # We allow "oidc" as an exception so that people migrating from old-style + # "oidc_config" format (which has long used "oidc" as its idp_id) can migrate to + # a new-style "oidc_providers" entry without changing the idp_id for their provider + # (and thereby invalidating their user_external_ids data). + + if idp_id != "oidc": + idp_id = "oidc-" + idp_id + # MSC2858 also specifies that the idp_icon must be a valid MXC uri idp_icon = oidc_config.get("idp_icon") if idp_icon is not None: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 2d25490374..2672ce24c6 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -446,7 +446,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): p.feed(channel.result["body"].decode("utf-8")) p.close() - self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "idp1", "saml"]) + self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"]) self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL) -- cgit 1.5.1 From c55e62548c0fddd49e7182133880d2ccb03dbb42 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 21 Jan 2021 15:18:46 +0100 Subject: Add tests for List Users Admin API (#9045) --- changelog.d/9045.misc | 1 + synapse/rest/admin/users.py | 21 +++- tests/rest/admin/test_user.py | 223 +++++++++++++++++++++++++++++++++++++----- 3 files changed, 215 insertions(+), 30 deletions(-) create mode 100644 changelog.d/9045.misc (limited to 'tests/rest') diff --git a/changelog.d/9045.misc b/changelog.d/9045.misc new file mode 100644 index 0000000000..7f1886a0de --- /dev/null +++ b/changelog.d/9045.misc @@ -0,0 +1 @@ +Add tests to `test_user.UsersListTestCase` for List Users Admin API. \ No newline at end of file diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index f39e3d6d5c..86198bab30 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -83,17 +83,32 @@ class UsersRestServletV2(RestServlet): The parameter `deactivated` can be used to include deactivated users. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + 400, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + user_id = parse_string(request, "user_id", default=None) name = parse_string(request, "name", default=None) guests = parse_boolean(request, "guests", default=True) @@ -103,7 +118,7 @@ class UsersRestServletV2(RestServlet): start, limit, user_id, name, guests, deactivated ) ret = {"users": users, "total": total} - if len(users) >= limit: + if (start + limit) < total: ret["next_token"] = str(start + len(users)) return 200, ret diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 04599c2fcf..e48f8c1d7b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -28,6 +28,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions from synapse.rest.client.v1 import login, logout, profile, room from synapse.rest.client.v2_alpha import devices, sync +from synapse.types import JsonDict from tests import unittest from tests.test_utils import make_awaitable @@ -468,13 +469,6 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - self.user1 = self.register_user( - "user1", "pass1", admin=False, displayname="Name 1" - ) - self.user2 = self.register_user( - "user2", "pass2", admin=False, displayname="Name 2" - ) - def test_no_auth(self): """ Try to list users without authentication. @@ -488,6 +482,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ If the user is not a server admin, an error is returned. """ + self._create_users(1) other_user_token = self.login("user1", "pass1") channel = self.make_request("GET", self.url, access_token=other_user_token) @@ -499,6 +494,8 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ List all users, including deactivated users. """ + self._create_users(2) + channel = self.make_request( "GET", self.url + "?deactivated=true", @@ -511,14 +508,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(3, channel.json_body["total"]) # Check that all fields are available - for u in channel.json_body["users"]: - self.assertIn("name", u) - self.assertIn("is_guest", u) - self.assertIn("admin", u) - self.assertIn("user_type", u) - self.assertIn("deactivated", u) - self.assertIn("displayname", u) - self.assertIn("avatar_url", u) + self._check_fields(channel.json_body["users"]) def test_search_term(self): """Test that searching for a users works correctly""" @@ -549,6 +539,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): # Check that users were returned self.assertTrue("users" in channel.json_body) + self._check_fields(channel.json_body["users"]) users = channel.json_body["users"] # Check that the expected number of users were returned @@ -561,25 +552,30 @@ class UsersListTestCase(unittest.HomeserverTestCase): u = users[0] self.assertEqual(expected_user_id, u["name"]) + self._create_users(2) + + user1 = "@user1:test" + user2 = "@user2:test" + # Perform search tests - _search_test(self.user1, "er1") - _search_test(self.user1, "me 1") + _search_test(user1, "er1") + _search_test(user1, "me 1") - _search_test(self.user2, "er2") - _search_test(self.user2, "me 2") + _search_test(user2, "er2") + _search_test(user2, "me 2") - _search_test(self.user1, "er1", "user_id") - _search_test(self.user2, "er2", "user_id") + _search_test(user1, "er1", "user_id") + _search_test(user2, "er2", "user_id") # Test case insensitive - _search_test(self.user1, "ER1") - _search_test(self.user1, "NAME 1") + _search_test(user1, "ER1") + _search_test(user1, "NAME 1") - _search_test(self.user2, "ER2") - _search_test(self.user2, "NAME 2") + _search_test(user2, "ER2") + _search_test(user2, "NAME 2") - _search_test(self.user1, "ER1", "user_id") - _search_test(self.user2, "ER2", "user_id") + _search_test(user1, "ER1", "user_id") + _search_test(user2, "ER2", "user_id") _search_test(None, "foo") _search_test(None, "bar") @@ -587,6 +583,179 @@ class UsersListTestCase(unittest.HomeserverTestCase): _search_test(None, "foo", "user_id") _search_test(None, "bar", "user_id") + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. + """ + + # negative limit + channel = self.make_request( + "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # negative from + channel = self.make_request( + "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # invalid guests + channel = self.make_request( + "GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # invalid deactivated + channel = self.make_request( + "GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + def test_limit(self): + """ + Testing list of users with limit + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 5) + self.assertEqual(channel.json_body["next_token"], "5") + self._check_fields(channel.json_body["users"]) + + def test_from(self): + """ + Testing list of users with a defined starting point (from) + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", self.url + "?from=5", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 15) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["users"]) + + def test_limit_and_from(self): + """ + Testing list of users with a defined starting point and limit + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(channel.json_body["next_token"], "15") + self.assertEqual(len(channel.json_body["users"]), 10) + self._check_fields(channel.json_body["users"]) + + def test_next_token(self): + """ + Testing that `next_token` appears at the right place + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + # `next_token` does not appear + # Number of results is the number of entries + channel = self.make_request( + "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), number_users) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + channel = self.make_request( + "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), number_users) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + channel = self.make_request( + "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 19) + self.assertEqual(channel.json_body["next_token"], "19") + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + channel = self.make_request( + "GET", self.url + "?from=19", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def _check_fields(self, content: JsonDict): + """Checks that the expected user attributes are present in content + Args: + content: List that is checked for content + """ + for u in content: + self.assertIn("name", u) + self.assertIn("is_guest", u) + self.assertIn("admin", u) + self.assertIn("user_type", u) + self.assertIn("deactivated", u) + self.assertIn("displayname", u) + self.assertIn("avatar_url", u) + + def _create_users(self, number_users: int): + """ + Create a number of users + Args: + number_users: Number of users to be created + """ + for i in range(1, number_users + 1): + self.register_user( + "user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i, + ) + class DeactivateAccountTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From a7882f98874684969910d3a6ed7d85f99114cc45 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 21 Jan 2021 14:53:58 -0500 Subject: Return a 404 if no valid thumbnail is found. (#9163) If no thumbnail of the requested type exists, return a 404 instead of erroring. This doesn't quite match the spec (which does not define what happens if no thumbnail can be found), but is consistent with what Synapse already does. --- changelog.d/9163.bugfix | 1 + synapse/rest/media/v1/_base.py | 3 + synapse/rest/media/v1/thumbnail_resource.py | 236 ++++++++++++++++++---------- tests/rest/media/v1/test_media_storage.py | 25 ++- 4 files changed, 183 insertions(+), 82 deletions(-) create mode 100644 changelog.d/9163.bugfix (limited to 'tests/rest') diff --git a/changelog.d/9163.bugfix b/changelog.d/9163.bugfix new file mode 100644 index 0000000000..c51cf6ca80 --- /dev/null +++ b/changelog.d/9163.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would return a 500 error when a thumbnail did not exist (and auto-generation of thumbnails was not enabled). diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 31a41e4a27..f71a03a12d 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -300,6 +300,7 @@ class FileInfo: thumbnail_height (int) thumbnail_method (str) thumbnail_type (str): Content type of thumbnail, e.g. image/png + thumbnail_length (int): The size of the media file, in bytes. """ def __init__( @@ -312,6 +313,7 @@ class FileInfo: thumbnail_height=None, thumbnail_method=None, thumbnail_type=None, + thumbnail_length=None, ): self.server_name = server_name self.file_id = file_id @@ -321,6 +323,7 @@ class FileInfo: self.thumbnail_height = thumbnail_height self.thumbnail_method = thumbnail_method self.thumbnail_type = thumbnail_type + self.thumbnail_length = thumbnail_length def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index d6880f2e6e..d653a58be9 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -16,7 +16,7 @@ import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional from twisted.web.http import Request @@ -106,31 +106,17 @@ class ThumbnailResource(DirectServeJsonResource): return thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) - - if thumbnail_infos: - thumbnail_info = self._select_thumbnail( - width, height, method, m_type, thumbnail_infos - ) - - file_info = FileInfo( - server_name=None, - file_id=media_id, - url_cache=media_info["url_cache"], - thumbnail=True, - thumbnail_width=thumbnail_info["thumbnail_width"], - thumbnail_height=thumbnail_info["thumbnail_height"], - thumbnail_type=thumbnail_info["thumbnail_type"], - thumbnail_method=thumbnail_info["thumbnail_method"], - ) - - t_type = file_info.thumbnail_type - t_length = thumbnail_info["thumbnail_length"] - - responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder(request, responder, t_type, t_length) - else: - logger.info("Couldn't find any generated thumbnails") - respond_404(request) + await self._select_and_respond_with_thumbnail( + request, + width, + height, + method, + m_type, + thumbnail_infos, + media_id, + url_cache=media_info["url_cache"], + server_name=None, + ) async def _select_or_generate_local_thumbnail( self, @@ -276,26 +262,64 @@ class ThumbnailResource(DirectServeJsonResource): thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) + await self._select_and_respond_with_thumbnail( + request, + width, + height, + method, + m_type, + thumbnail_infos, + media_info["filesystem_id"], + url_cache=None, + server_name=server_name, + ) + async def _select_and_respond_with_thumbnail( + self, + request: Request, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + thumbnail_infos: List[Dict[str, Any]], + file_id: str, + url_cache: Optional[str] = None, + server_name: Optional[str] = None, + ) -> None: + """ + Respond to a request with an appropriate thumbnail from the previously generated thumbnails. + + Args: + request: The incoming request. + desired_width: The desired width, the returned thumbnail may be larger than this. + desired_height: The desired height, the returned thumbnail may be larger than this. + desired_method: The desired method used to generate the thumbnail. + desired_type: The desired content-type of the thumbnail. + thumbnail_infos: A list of dictionaries of candidate thumbnails. + file_id: The ID of the media that a thumbnail is being requested for. + url_cache: The URL cache value. + server_name: The server name, if this is a remote thumbnail. + """ if thumbnail_infos: - thumbnail_info = self._select_thumbnail( - width, height, method, m_type, thumbnail_infos + file_info = self._select_thumbnail( + desired_width, + desired_height, + desired_method, + desired_type, + thumbnail_infos, + file_id, + url_cache, + server_name, ) - file_info = FileInfo( - server_name=server_name, - file_id=media_info["filesystem_id"], - thumbnail=True, - thumbnail_width=thumbnail_info["thumbnail_width"], - thumbnail_height=thumbnail_info["thumbnail_height"], - thumbnail_type=thumbnail_info["thumbnail_type"], - thumbnail_method=thumbnail_info["thumbnail_method"], - ) - - t_type = file_info.thumbnail_type - t_length = thumbnail_info["thumbnail_length"] + if not file_info: + logger.info("Couldn't find a thumbnail matching the desired inputs") + respond_404(request) + return responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder( + request, responder, file_info.thumbnail_type, file_info.thumbnail_length + ) else: logger.info("Failed to find any generated thumbnails") respond_404(request) @@ -306,67 +330,117 @@ class ThumbnailResource(DirectServeJsonResource): desired_height: int, desired_method: str, desired_type: str, - thumbnail_infos, - ) -> dict: + thumbnail_infos: List[Dict[str, Any]], + file_id: str, + url_cache: Optional[str], + server_name: Optional[str], + ) -> Optional[FileInfo]: + """ + Choose an appropriate thumbnail from the previously generated thumbnails. + + Args: + desired_width: The desired width, the returned thumbnail may be larger than this. + desired_height: The desired height, the returned thumbnail may be larger than this. + desired_method: The desired method used to generate the thumbnail. + desired_type: The desired content-type of the thumbnail. + thumbnail_infos: A list of dictionaries of candidate thumbnails. + file_id: The ID of the media that a thumbnail is being requested for. + url_cache: The URL cache value. + server_name: The server name, if this is a remote thumbnail. + + Returns: + The thumbnail which best matches the desired parameters. + """ + desired_method = desired_method.lower() + + # The chosen thumbnail. + thumbnail_info = None + d_w = desired_width d_h = desired_height - if desired_method.lower() == "crop": + if desired_method == "crop": + # Thumbnails that match equal or larger sizes of desired width/height. crop_info_list = [] + # Other thumbnails. crop_info_list2 = [] for info in thumbnail_infos: + # Skip thumbnails generated with different methods. + if info["thumbnail_method"] != "crop": + continue + t_w = info["thumbnail_width"] t_h = info["thumbnail_height"] - t_method = info["thumbnail_method"] - if t_method == "crop": - aspect_quality = abs(d_w * t_h - d_h * t_w) - min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 - size_quality = abs((d_w - t_w) * (d_h - t_h)) - type_quality = desired_type != info["thumbnail_type"] - length_quality = info["thumbnail_length"] - if t_w >= d_w or t_h >= d_h: - crop_info_list.append( - ( - aspect_quality, - min_quality, - size_quality, - type_quality, - length_quality, - info, - ) + aspect_quality = abs(d_w * t_h - d_h * t_w) + min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 + size_quality = abs((d_w - t_w) * (d_h - t_h)) + type_quality = desired_type != info["thumbnail_type"] + length_quality = info["thumbnail_length"] + if t_w >= d_w or t_h >= d_h: + crop_info_list.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, ) - else: - crop_info_list2.append( - ( - aspect_quality, - min_quality, - size_quality, - type_quality, - length_quality, - info, - ) + ) + else: + crop_info_list2.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, ) + ) if crop_info_list: - return min(crop_info_list)[-1] - else: - return min(crop_info_list2)[-1] - else: + thumbnail_info = min(crop_info_list)[-1] + elif crop_info_list2: + thumbnail_info = min(crop_info_list2)[-1] + elif desired_method == "scale": + # Thumbnails that match equal or larger sizes of desired width/height. info_list = [] + # Other thumbnails. info_list2 = [] + for info in thumbnail_infos: + # Skip thumbnails generated with different methods. + if info["thumbnail_method"] != "scale": + continue + t_w = info["thumbnail_width"] t_h = info["thumbnail_height"] - t_method = info["thumbnail_method"] size_quality = abs((d_w - t_w) * (d_h - t_h)) type_quality = desired_type != info["thumbnail_type"] length_quality = info["thumbnail_length"] - if t_method == "scale" and (t_w >= d_w or t_h >= d_h): + if t_w >= d_w or t_h >= d_h: info_list.append((size_quality, type_quality, length_quality, info)) - elif t_method == "scale": + else: info_list2.append( (size_quality, type_quality, length_quality, info) ) if info_list: - return min(info_list)[-1] - else: - return min(info_list2)[-1] + thumbnail_info = min(info_list)[-1] + elif info_list2: + thumbnail_info = min(info_list2)[-1] + + if thumbnail_info: + return FileInfo( + file_id=file_id, + url_cache=url_cache, + server_name=server_name, + thumbnail=True, + thumbnail_width=thumbnail_info["thumbnail_width"], + thumbnail_height=thumbnail_info["thumbnail_height"], + thumbnail_type=thumbnail_info["thumbnail_type"], + thumbnail_method=thumbnail_info["thumbnail_method"], + thumbnail_length=thumbnail_info["thumbnail_length"], + ) + + # No matching thumbnail was found. + return None diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index ae2b32b131..a6c6985173 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -202,7 +202,6 @@ class MediaRepoTests(unittest.HomeserverTestCase): config = self.default_config() config["media_store_path"] = self.media_store_path - config["thumbnail_requirements"] = {} config["max_image_pixels"] = 2000000 provider_config = { @@ -313,15 +312,39 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) def test_thumbnail_crop(self): + """Test that a cropped remote thumbnail is available.""" self._test_thumbnail( "crop", self.test_image.expected_cropped, self.test_image.expected_found ) def test_thumbnail_scale(self): + """Test that a scaled remote thumbnail is available.""" self._test_thumbnail( "scale", self.test_image.expected_scaled, self.test_image.expected_found ) + def test_invalid_type(self): + """An invalid thumbnail type is never available.""" + self._test_thumbnail("invalid", None, False) + + @unittest.override_config( + {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]} + ) + def test_no_thumbnail_crop(self): + """ + Override the config to generate only scaled thumbnails, but request a cropped one. + """ + self._test_thumbnail("crop", None, False) + + @unittest.override_config( + {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]} + ) + def test_no_thumbnail_scale(self): + """ + Override the config to generate only cropped thumbnails, but request a scaled one. + """ + self._test_thumbnail("scale", None, False) + def _test_thumbnail(self, method, expected_body, expected_found): params = "?width=32&height=32&method=" + method channel = make_request( -- cgit 1.5.1 From 4a55d267eef1388690e6781b580910e341358f95 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 25 Jan 2021 14:49:39 -0500 Subject: Add an admin API for shadow-banning users. (#9209) This expands the current shadow-banning feature to be usable via the admin API and adds documentation for it. A shadow-banned users receives successful responses to their client-server API requests, but the events are not propagated into rooms. Shadow-banning a user should be used as a tool of last resort and may lead to confusing or broken behaviour for the client. --- changelog.d/9209.feature | 1 + docs/admin_api/user_admin_api.rst | 30 ++++++++++++ stubs/txredisapi.pyi | 1 - synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/users.py | 36 +++++++++++++++ synapse/storage/databases/main/registration.py | 29 ++++++++++++ tests/rest/admin/test_user.py | 64 ++++++++++++++++++++++++++ tests/rest/client/test_shadow_banned.py | 8 +--- 8 files changed, 164 insertions(+), 7 deletions(-) create mode 100644 changelog.d/9209.feature (limited to 'tests/rest') diff --git a/changelog.d/9209.feature b/changelog.d/9209.feature new file mode 100644 index 0000000000..ec926e8eb4 --- /dev/null +++ b/changelog.d/9209.feature @@ -0,0 +1 @@ +Add an admin API endpoint for shadow-banning users. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index b3d413cf57..1eb674939e 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -760,3 +760,33 @@ The following fields are returned in the JSON response body: - ``total`` - integer - Number of pushers. See also `Client-Server API Spec `_ + +Shadow-banning users +==================== + +Shadow-banning is a useful tool for moderating malicious or egregiously abusive users. +A shadow-banned users receives successful responses to their client-server API requests, +but the events are not propagated into rooms. This can be an effective tool as it +(hopefully) takes longer for the user to realise they are being moderated before +pivoting to another account. + +Shadow-banning a user should be used as a tool of last resort and may lead to confusing +or broken behaviour for the client. A shadow-banned user will not receive any +notification and it is generally more appropriate to ban or kick abusive users. +A shadow-banned user will be unable to contact anyone on the server. + +The API is:: + + POST /_synapse/admin/v1/users//shadow_ban + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst `_. + +An empty JSON dict is returned. + +**Parameters** + +The following parameters should be set in the URL: + +- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must + be local. diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index bfac6840e6..726454ba31 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -15,7 +15,6 @@ """Contains *incomplete* type hints for txredisapi. """ - from typing import List, Optional, Type, Union class RedisProtocol: diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 6f7dc06503..f04740cd38 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -51,6 +51,7 @@ from synapse.rest.admin.users import ( PushersRestServlet, ResetPasswordRestServlet, SearchUsersRestServlet, + ShadowBanRestServlet, UserAdminServlet, UserMediaRestServlet, UserMembershipRestServlet, @@ -230,6 +231,7 @@ def register_servlets(hs, http_server): EventReportsRestServlet(hs).register(http_server) PushersRestServlet(hs).register(http_server) MakeRoomAdminRestServlet(hs).register(http_server) + ShadowBanRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 86198bab30..68c3c64a0d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -890,3 +890,39 @@ class UserTokenRestServlet(RestServlet): ) return 200, {"access_token": token} + + +class ShadowBanRestServlet(RestServlet): + """An admin API for shadow-banning a user. + + A shadow-banned users receives successful responses to their client-server + API requests, but the events are not propagated into rooms. + + Shadow-banning a user should be used as a tool of last resort and may lead + to confusing or broken behaviour for the client. + + Example: + + POST /_synapse/admin/v1/users/@test:example.com/shadow_ban + {} + + 200 OK + {} + """ + + PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + if not self.hs.is_mine_id(user_id): + raise SynapseError(400, "Only local users can be shadow-banned") + + await self.store.set_shadow_banned(UserID.from_string(user_id), True) + + return 200, {} diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 585b4049d6..0618b4387a 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -360,6 +360,35 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) + async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None: + """Sets whether a user shadow-banned. + + Args: + user: user ID of the user to test + shadow_banned: true iff the user is to be shadow-banned, false otherwise. + """ + + def set_shadow_banned_txn(txn): + self.db_pool.simple_update_one_txn( + txn, + table="users", + keyvalues={"name": user.to_string()}, + updatevalues={"shadow_banned": shadow_banned}, + ) + # In order for this to apply immediately, clear the cache for this user. + tokens = self.db_pool.simple_select_onecol_txn( + txn, + table="access_tokens", + keyvalues={"user_id": user.to_string()}, + retcol="token", + ) + for token in tokens: + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (token,) + ) + + await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn) + def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: sql = """ SELECT users.name as user_id, diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index e48f8c1d7b..ee05ee60bc 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2380,3 +2380,67 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) + + +class ShadowBanRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + + self.url = "/_synapse/admin/v1/users/%s/shadow_ban" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to get information of an user without authentication. + """ + channel = self.make_request("POST", self.url) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_not_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user", "pass") + + channel = self.make_request("POST", self.url, access_token=other_user_token) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that shadow-banning for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" + + channel = self.make_request("POST", url, access_token=self.admin_user_tok) + self.assertEqual(400, channel.code, msg=channel.json_body) + + def test_success(self): + """ + Shadow-banning should succeed for an admin. + """ + # The user starts off as not shadow-banned. + other_user_token = self.login("user", "pass") + result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + self.assertFalse(result.shadow_banned) + + channel = self.make_request("POST", self.url, access_token=self.admin_user_tok) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual({}, channel.json_body) + + # Ensure the user is shadow-banned (and the cache was cleared). + result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + self.assertTrue(result.shadow_banned) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index e689c3fbea..0ebdf1415b 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -18,6 +18,7 @@ import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet +from synapse.types import UserID from tests import unittest @@ -31,12 +32,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase): self.store = self.hs.get_datastore() self.get_success( - self.store.db_pool.simple_update( - table="users", - keyvalues={"name": self.banned_user_id}, - updatevalues={"shadow_banned": True}, - desc="shadow_ban", - ) + self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True) ) self.other_user_id = self.register_user("otheruser", "pass") -- cgit 1.5.1 From a737cc27134c50059440ca33510b0baea53b4225 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 27 Jan 2021 12:41:24 +0000 Subject: Implement MSC2858 support (#9183) Fixes #8928. --- changelog.d/9183.feature | 1 + synapse/config/_base.pyi | 2 + synapse/config/experimental.py | 29 ++++++++++++ synapse/config/homeserver.py | 2 + synapse/handlers/sso.py | 23 +++++++--- synapse/http/server.py | 44 ++++++++++++++---- synapse/rest/client/v1/login.py | 55 ++++++++++++++++++++--- tests/rest/client/v1/test_login.py | 92 ++++++++++++++++++++++++++++++++++++++ tests/utils.py | 3 +- 9 files changed, 230 insertions(+), 21 deletions(-) create mode 100644 changelog.d/9183.feature create mode 100644 synapse/config/experimental.py (limited to 'tests/rest') diff --git a/changelog.d/9183.feature b/changelog.d/9183.feature new file mode 100644 index 0000000000..2d5c735042 --- /dev/null +++ b/changelog.d/9183.feature @@ -0,0 +1 @@ +Add experimental support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858). diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 29aa064e57..3ccea4b02d 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -9,6 +9,7 @@ from synapse.config import ( consent_config, database, emailconfig, + experimental, groups, jwt_config, key, @@ -48,6 +49,7 @@ def path_exists(file_path: str): ... class RootConfig: server: server.ServerConfig + experimental: experimental.ExperimentalConfig tls: tls.TlsConfig database: database.DatabaseConfig logging: logger.LoggingConfig diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py new file mode 100644 index 0000000000..b1c1c51e4d --- /dev/null +++ b/synapse/config/experimental.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config._base import Config +from synapse.types import JsonDict + + +class ExperimentalConfig(Config): + """Config section for enabling experimental features""" + + section = "experimental" + + def read_config(self, config: JsonDict, **kwargs): + experimental = config.get("experimental_features") or {} + + # MSC2858 (multiple SSO identity providers) + self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 4bd2b3587b..64a2429f77 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -24,6 +24,7 @@ from .cas import CasConfig from .consent_config import ConsentConfig from .database import DatabaseConfig from .emailconfig import EmailConfig +from .experimental import ExperimentalConfig from .federation import FederationConfig from .groups import GroupsConfig from .jwt_config import JWTConfig @@ -57,6 +58,7 @@ class HomeServerConfig(RootConfig): config_classes = [ ServerConfig, + ExperimentalConfig, TlsConfig, FederationConfig, CacheConfig, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index d493327a10..afc1341d09 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -23,7 +23,7 @@ from typing_extensions import NoReturn, Protocol from twisted.web.http import Request from synapse.api.constants import LoginType -from synapse.api.errors import Codes, RedirectException, SynapseError +from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html @@ -235,7 +235,10 @@ class SsoHandler: respond_with_html(request, code, html) async def handle_redirect_request( - self, request: SynapseRequest, client_redirect_url: bytes, + self, + request: SynapseRequest, + client_redirect_url: bytes, + idp_id: Optional[str], ) -> str: """Handle a request to /login/sso/redirect @@ -243,6 +246,7 @@ class SsoHandler: request: incoming HTTP request client_redirect_url: the URL that we should redirect the client to after login. + idp_id: optional identity provider chosen by the client Returns: the URI to redirect to @@ -252,10 +256,19 @@ class SsoHandler: 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED ) + # if the client chose an IdP, use that + idp = None # type: Optional[SsoIdentityProvider] + if idp_id: + idp = self._identity_providers.get(idp_id) + if not idp: + raise NotFoundError("Unknown identity provider") + # if we only have one auth provider, redirect to it directly - if len(self._identity_providers) == 1: - ap = next(iter(self._identity_providers.values())) - return await ap.handle_redirect_request(request, client_redirect_url) + elif len(self._identity_providers) == 1: + idp = next(iter(self._identity_providers.values())) + + if idp: + return await idp.handle_redirect_request(request, client_redirect_url) # otherwise, redirect to the IDP picker return "/_synapse/client/pick_idp?" + urlencode( diff --git a/synapse/http/server.py b/synapse/http/server.py index e464bfe6c7..d69d579b3a 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -22,10 +22,22 @@ import types import urllib from http import HTTPStatus from io import BytesIO -from typing import Any, Callable, Dict, Iterator, List, Tuple, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterable, + Iterator, + List, + Pattern, + Tuple, + Union, +) import jinja2 from canonicaljson import iterencode_canonical_json +from typing_extensions import Protocol from zope.interface import implementer from twisted.internet import defer, interfaces @@ -168,11 +180,25 @@ def wrap_async_request_handler(h): return preserve_fn(wrapped_async_request_handler) -class HttpServer: +# Type of a callback method for processing requests +# it is actually called with a SynapseRequest and a kwargs dict for the params, +# but I can't figure out how to represent that. +ServletCallback = Callable[ + ..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]] +] + + +class HttpServer(Protocol): """ Interface for registering callbacks on a HTTP server """ - def register_paths(self, method, path_patterns, callback): + def register_paths( + self, + method: str, + path_patterns: Iterable[Pattern], + callback: ServletCallback, + servlet_classname: str, + ) -> None: """ Register a callback that gets fired if we receive a http request with the given method for a path that matches the given regex. @@ -180,12 +206,14 @@ class HttpServer: an unpacked tuple. Args: - method (str): The method to listen to. - path_patterns (list): The regex used to match requests. - callback (function): The function to fire if we receive a matched + method: The HTTP method to listen to. + path_patterns: The regex used to match requests. + callback: The function to fire if we receive a matched request. The first argument will be the request object and subsequent arguments will be any matched groups from the regex. - This should return a tuple of (code, response). + This should return either tuple of (code, response), or None. + servlet_classname (str): The name of the handler to be used in prometheus + and opentracing logs. """ pass @@ -354,7 +382,7 @@ class JsonResource(DirectServeJsonResource): def _get_handler_for_request( self, request: SynapseRequest - ) -> Tuple[Callable, str, Dict[str, str]]: + ) -> Tuple[ServletCallback, str, Dict[str, str]]: """Finds a callback method to handle the given request. Returns: diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index be938df962..0a561eea60 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.appservice import ApplicationService -from synapse.http.server import finish_request +from synapse.handlers.sso import SsoIdentityProvider +from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, @@ -60,11 +61,14 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled + self._msc2858_enabled = hs.config.experimental.msc2858_enabled self.auth = hs.get_auth() self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() + self._sso_handler = hs.get_sso_handler() + self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( clock=hs.get_clock(), @@ -89,8 +93,17 @@ class LoginRestServlet(RestServlet): flows.append({"type": LoginRestServlet.CAS_TYPE}) if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: - flows.append({"type": LoginRestServlet.SSO_TYPE}) - # While its valid for us to advertise this login type generally, + sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict + + if self._msc2858_enabled: + sso_flow["org.matrix.msc2858.identity_providers"] = [ + _get_auth_flow_dict_for_idp(idp) + for idp in self._sso_handler.get_identity_providers().values() + ] + + flows.append(sso_flow) + + # While it's valid for us to advertise this login type generally, # synapse currently only gives out these tokens as part of the # SSO login flow. # Generally we don't want to advertise login flows that clients @@ -311,8 +324,20 @@ class LoginRestServlet(RestServlet): return result +def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: + """Return an entry for the login flow dict + + Returns an entry suitable for inclusion in "identity_providers" in the + response to GET /_matrix/client/r0/login + """ + e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict + if idp.idp_icon: + e["icon"] = idp.idp_icon + return e + + class SsoRedirectServlet(RestServlet): - PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) + PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True) def __init__(self, hs: "HomeServer"): # make sure that the relevant handlers are instantiated, so that they @@ -324,13 +349,31 @@ class SsoRedirectServlet(RestServlet): if hs.config.oidc_enabled: hs.get_oidc_handler() self._sso_handler = hs.get_sso_handler() + self._msc2858_enabled = hs.config.experimental.msc2858_enabled + + def register(self, http_server: HttpServer) -> None: + super().register(http_server) + if self._msc2858_enabled: + # expose additional endpoint for MSC2858 support + http_server.register_paths( + "GET", + client_patterns( + "/org.matrix.msc2858/login/sso/redirect/(?P[A-Za-z0-9_.~-]+)$", + releases=(), + unstable=True, + ), + self.on_GET, + self.__class__.__name__, + ) - async def on_GET(self, request: SynapseRequest): + async def on_GET( + self, request: SynapseRequest, idp_id: Optional[str] = None + ) -> None: client_redirect_url = parse_string( request, "redirectUrl", required=True, encoding=None ) sso_url = await self._sso_handler.handle_redirect_request( - request, client_redirect_url + request, client_redirect_url, idp_id, ) logger.info("Redirecting to %s", sso_url) request.redirect(sso_url) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 2672ce24c6..e2bb945453 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -75,6 +75,10 @@ TEST_CLIENT_REDIRECT_URL = 'https://x?&q"+%3D%2B"="fö%26=o"' # the query params in TEST_CLIENT_REDIRECT_URL EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("", ""), ('q" =+"', '"fö&=o"')] +# (possibly experimental) login flows we expect to appear in the list after the normal +# ones +ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] + class LoginRestServletTestCase(unittest.HomeserverTestCase): @@ -426,6 +430,57 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): d["/_synapse/oidc"] = OIDCResource(self.hs) return d + def test_get_login_flows(self): + """GET /login should return password and SSO flows""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + expected_flows = [ + {"type": "m.login.cas"}, + {"type": "m.login.sso"}, + {"type": "m.login.token"}, + {"type": "m.login.password"}, + ] + ADDITIONAL_LOGIN_FLOWS + + self.assertCountEqual(channel.json_body["flows"], expected_flows) + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_get_msc2858_login_flows(self): + """The SSO flow should include IdP info if MSC2858 is enabled""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + # stick the flows results in a dict by type + flow_results = {} # type: Dict[str, Any] + for f in channel.json_body["flows"]: + flow_type = f["type"] + self.assertNotIn( + flow_type, flow_results, "duplicate flow type %s" % (flow_type,) + ) + flow_results[flow_type] = f + + self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned") + sso_flow = flow_results.pop("m.login.sso") + # we should have a set of IdPs + self.assertCountEqual( + sso_flow["org.matrix.msc2858.identity_providers"], + [ + {"id": "cas", "name": "CAS"}, + {"id": "saml", "name": "SAML"}, + {"id": "oidc-idp1", "name": "IDP1"}, + {"id": "oidc", "name": "OIDC"}, + ], + ) + + # the rest of the flows are simple + expected_flows = [ + {"type": "m.login.cas"}, + {"type": "m.login.token"}, + {"type": "m.login.password"}, + ] + ADDITIONAL_LOGIN_FLOWS + + self.assertCountEqual(flow_results.values(), expected_flows) + def test_multi_sso_redirect(self): """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker @@ -564,6 +619,43 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) + def test_client_idp_redirect_msc2858_disabled(self): + """If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_idp_redirect_to_unknown(self): + """If the client tries to pick an unknown IdP, return a 404""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + self.assertEqual(channel.code, 404, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_idp_redirect_to_oidc(self): + """If the client pick a known IdP, redirect to it""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + @staticmethod def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: prefix = key + " = " diff --git a/tests/utils.py b/tests/utils.py index 09614093bc..022223cf24 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -33,7 +33,6 @@ from synapse.api.room_versions import RoomVersions from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION -from synapse.http.server import HttpServer from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer from synapse.storage import DataStore @@ -351,7 +350,7 @@ def mock_getRawHeaders(headers=None): # This is a mock /resource/ not an entire server -class MockHttpResource(HttpServer): +class MockHttpResource: def __init__(self, prefix=""): self.callbacks = [] # 3-tuple of method/pattern/function self.prefix = prefix -- cgit 1.5.1 From 4b73488e811714089ba447884dccb9b6ae3ac16c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2021 17:39:21 +0000 Subject: Ratelimit 3PID /requestToken API (#9238) --- changelog.d/9238.feature | 1 + docs/sample_config.yaml | 6 +- synapse/config/_base.pyi | 2 +- synapse/config/ratelimiting.py | 13 ++++- synapse/handlers/identity.py | 28 ++++++++++ synapse/rest/client/v2_alpha/account.py | 12 +++- synapse/rest/client/v2_alpha/register.py | 6 ++ tests/rest/client/v2_alpha/test_account.py | 90 ++++++++++++++++++++++++++++-- tests/server.py | 9 ++- tests/unittest.py | 5 ++ tests/utils.py | 1 + 11 files changed, 159 insertions(+), 14 deletions(-) create mode 100644 changelog.d/9238.feature (limited to 'tests/rest') diff --git a/changelog.d/9238.feature b/changelog.d/9238.feature new file mode 100644 index 0000000000..143a3e14f5 --- /dev/null +++ b/changelog.d/9238.feature @@ -0,0 +1 @@ +Add ratelimited to 3PID `/requestToken` API. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index c2ccd68f3a..e5b6268087 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -824,6 +824,7 @@ log_config: "CONFDIR/SERVERNAME.log.config" # users are joining rooms the server is already in (this is cheap) vs # "remote" for when users are trying to join rooms not on the server (which # can be more expensive) +# - one for ratelimiting how often a user or IP can attempt to validate a 3PID. # # The defaults are as shown below. # @@ -857,7 +858,10 @@ log_config: "CONFDIR/SERVERNAME.log.config" # remote: # per_second: 0.01 # burst_count: 3 - +# +#rc_3pid_validation: +# per_second: 0.003 +# burst_count: 5 # Ratelimiting settings for incoming federation # diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 7ed07a801d..70025b5d60 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -54,7 +54,7 @@ class RootConfig: tls: tls.TlsConfig database: database.DatabaseConfig logging: logger.LoggingConfig - ratelimit: ratelimiting.RatelimitConfig + ratelimiting: ratelimiting.RatelimitConfig media: repository.ContentRepositoryConfig captcha: captcha.CaptchaConfig voip: voip.VoipConfig diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 14b8836197..76f382527d 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -24,7 +24,7 @@ class RateLimitConfig: defaults={"per_second": 0.17, "burst_count": 3.0}, ): self.per_second = config.get("per_second", defaults["per_second"]) - self.burst_count = config.get("burst_count", defaults["burst_count"]) + self.burst_count = int(config.get("burst_count", defaults["burst_count"])) class FederationRateLimitConfig: @@ -102,6 +102,11 @@ class RatelimitConfig(Config): defaults={"per_second": 0.01, "burst_count": 3}, ) + self.rc_3pid_validation = RateLimitConfig( + config.get("rc_3pid_validation") or {}, + defaults={"per_second": 0.003, "burst_count": 5}, + ) + def generate_config_section(self, **kwargs): return """\ ## Ratelimiting ## @@ -131,6 +136,7 @@ class RatelimitConfig(Config): # users are joining rooms the server is already in (this is cheap) vs # "remote" for when users are trying to join rooms not on the server (which # can be more expensive) + # - one for ratelimiting how often a user or IP can attempt to validate a 3PID. # # The defaults are as shown below. # @@ -164,7 +170,10 @@ class RatelimitConfig(Config): # remote: # per_second: 0.01 # burst_count: 3 - + # + #rc_3pid_validation: + # per_second: 0.003 + # burst_count: 5 # Ratelimiting settings for incoming federation # diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index f61844d688..4f7137539b 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -27,9 +27,11 @@ from synapse.api.errors import ( HttpResponseException, SynapseError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.config.emailconfig import ThreepidBehaviour from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient +from synapse.http.site import SynapseRequest from synapse.types import JsonDict, Requester from synapse.util import json_decoder from synapse.util.hash import sha256_and_url_safe_base64 @@ -57,6 +59,32 @@ class IdentityHandler(BaseHandler): self._web_client_location = hs.config.invite_client_location + # Ratelimiters for `/requestToken` endpoints. + self._3pid_validation_ratelimiter_ip = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + ) + self._3pid_validation_ratelimiter_address = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + ) + + def ratelimit_request_token_requests( + self, request: SynapseRequest, medium: str, address: str, + ): + """Used to ratelimit requests to `/requestToken` by IP and address. + + Args: + request: The associated request + medium: The type of threepid, e.g. "msisdn" or "email" + address: The actual threepid ID, e.g. the phone number or email address + """ + + self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP())) + self._3pid_validation_ratelimiter_address.ratelimit((medium, address)) + async def threepid_from_creds( self, id_server: str, creds: Dict[str, str] ) -> Optional[JsonDict]: diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 65e68d641b..a84a2fb385 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -54,7 +54,7 @@ logger = logging.getLogger(__name__) class EmailPasswordRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/password/email/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.datastore = hs.get_datastore() @@ -103,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + # The email will be sent to the stored address. # This avoids a potential account hijack by requesting a password reset to # an email address which is controlled by the attacker but which, after @@ -379,6 +381,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + if next_link: # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) @@ -430,7 +434,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): class MsisdnThreepidRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs super().__init__() self.store = self.hs.get_datastore() @@ -458,6 +462,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + if next_link: # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index b093183e79..10e1891174 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -126,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email ) @@ -205,6 +207,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "msisdn", msisdn ) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index cb87b80e33..177dc476da 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -24,7 +24,7 @@ import pkg_resources import synapse.rest.admin from synapse.api.constants import LoginType, Membership -from synapse.api.errors import Codes +from synapse.api.errors import Codes, HttpResponseException from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource @@ -112,6 +112,56 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the old password self.attempt_wrong_password_login("kermit", old_password) + @override_config({"rc_3pid_validation": {"burst_count": 3}}) + def test_ratelimit_by_email(self): + """Test that we ratelimit /requestToken for the same email. + """ + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test1@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + def reset(ip): + client_secret = "foobar" + session_id = self._request_token(email, client_secret, ip) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + self.email_attempts.clear() + + # We expect to be able to make three requests before getting rate + # limited. + # + # We change IPs to ensure that we're not being ratelimited due to the + # same IP + reset("127.0.0.1") + reset("127.0.0.2") + reset("127.0.0.3") + + with self.assertRaises(HttpResponseException) as cm: + reset("127.0.0.4") + + self.assertEqual(cm.exception.code, 429) + def test_basic_password_reset_canonicalise_email(self): """Test basic password reset flow Request password reset with different spelling @@ -239,13 +289,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(session_id) - def _request_token(self, email, client_secret): + def _request_token(self, email, client_secret, ip="127.0.0.1"): channel = self.make_request( "POST", b"account/password/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, + client_ip=ip, ) - self.assertEquals(200, channel.code, channel.result) + + if channel.code != 200: + raise HttpResponseException( + channel.code, channel.result["reason"], channel.result["body"], + ) return channel.json_body["sid"] @@ -509,6 +564,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_address_trim(self): self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) + @override_config({"rc_3pid_validation": {"burst_count": 3}}) + def test_ratelimit_by_ip(self): + """Tests that adding emails is ratelimited by IP + """ + + # We expect to be able to set three emails before getting ratelimited. + self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) + self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar")) + self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar")) + + with self.assertRaises(HttpResponseException) as cm: + self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar")) + + self.assertEqual(cm.exception.code, 429) + def test_add_email_if_disabled(self): """Test adding email to profile when doing so is disallowed """ @@ -777,7 +847,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): body["next_link"] = next_link channel = self.make_request("POST", b"account/3pid/email/requestToken", body,) - self.assertEquals(expect_code, channel.code, channel.result) + + if channel.code != expect_code: + raise HttpResponseException( + channel.code, channel.result["reason"], channel.result["body"], + ) return channel.json_body.get("sid") @@ -823,10 +897,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def _add_email(self, request_email, expected_email): """Test adding an email to profile """ + previous_email_attempts = len(self.email_attempts) + client_secret = "foobar" session_id = self._request_token(request_email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1) link = self._get_link_from_email() self._validate_token(link) @@ -855,4 +931,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"]) + + threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} + self.assertIn(expected_email, threepids) diff --git a/tests/server.py b/tests/server.py index 5a85d5fe7f..6419c445ec 100644 --- a/tests/server.py +++ b/tests/server.py @@ -47,6 +47,7 @@ class FakeChannel: site = attr.ib(type=Site) _reactor = attr.ib() result = attr.ib(type=dict, default=attr.Factory(dict)) + _ip = attr.ib(type=str, default="127.0.0.1") _producer = None @property @@ -120,7 +121,7 @@ class FakeChannel: def getPeer(self): # We give an address so that getClientIP returns a non null entry, # causing us to record the MAU - return address.IPv4Address("TCP", "127.0.0.1", 3423) + return address.IPv4Address("TCP", self._ip, 3423) def getHost(self): return None @@ -196,6 +197,7 @@ def make_request( custom_headers: Optional[ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] ] = None, + client_ip: str = "127.0.0.1", ) -> FakeChannel: """ Make a web request using the given method, path and content, and render it @@ -223,6 +225,9 @@ def make_request( will pump the reactor until the the renderer tells the channel the request is finished. + client_ip: The IP to use as the requesting IP. Useful for testing + ratelimiting. + Returns: channel """ @@ -250,7 +255,7 @@ def make_request( if isinstance(content, str): content = content.encode("utf8") - channel = FakeChannel(site, reactor) + channel = FakeChannel(site, reactor, ip=client_ip) req = request(channel) req.content = BytesIO(content) diff --git a/tests/unittest.py b/tests/unittest.py index bbd295687c..767d5d6077 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -386,6 +386,7 @@ class HomeserverTestCase(TestCase): custom_headers: Optional[ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] ] = None, + client_ip: str = "127.0.0.1", ) -> FakeChannel: """ Create a SynapseRequest at the path using the method and containing the @@ -410,6 +411,9 @@ class HomeserverTestCase(TestCase): custom_headers: (name, value) pairs to add as request headers + client_ip: The IP to use as the requesting IP. Useful for testing + ratelimiting. + Returns: The FakeChannel object which stores the result of the request. """ @@ -426,6 +430,7 @@ class HomeserverTestCase(TestCase): content_is_form, await_result, custom_headers, + client_ip, ) def setup_test_homeserver(self, *args, **kwargs): diff --git a/tests/utils.py b/tests/utils.py index 022223cf24..68033d7535 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -157,6 +157,7 @@ def default_config(name, parse=False): "local": {"per_second": 10000, "burst_count": 10000}, "remote": {"per_second": 10000, "burst_count": 10000}, }, + "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000}, "saml2_enabled": False, "default_identity_server": None, "key_refresh_interval": 24 * 60 * 60 * 1000, -- cgit 1.5.1 From f2c1560eca1e2160087a280261ca78d0708ad721 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 29 Jan 2021 16:38:29 +0000 Subject: Ratelimit invites by room and target user (#9258) --- changelog.d/9258.feature | 1 + docs/sample_config.yaml | 10 ++++ synapse/config/ratelimiting.py | 19 +++++++ synapse/federation/federation_client.py | 2 +- synapse/handlers/federation.py | 4 ++ synapse/handlers/room.py | 7 +++ synapse/handlers/room_member.py | 25 ++++++++- tests/handlers/test_federation.py | 93 ++++++++++++++++++++++++++++++++- tests/rest/client/v1/test_rooms.py | 35 +++++++++++++ 9 files changed, 192 insertions(+), 4 deletions(-) create mode 100644 changelog.d/9258.feature (limited to 'tests/rest') diff --git a/changelog.d/9258.feature b/changelog.d/9258.feature new file mode 100644 index 0000000000..0028f42d26 --- /dev/null +++ b/changelog.d/9258.feature @@ -0,0 +1 @@ +Add ratelimits to invites in rooms and to specific users. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 332befd948..7fd35516dc 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -825,6 +825,8 @@ log_config: "CONFDIR/SERVERNAME.log.config" # "remote" for when users are trying to join rooms not on the server (which # can be more expensive) # - one for ratelimiting how often a user or IP can attempt to validate a 3PID. +# - two for ratelimiting how often invites can be sent in a room or to a +# specific user. # # The defaults are as shown below. # @@ -862,6 +864,14 @@ log_config: "CONFDIR/SERVERNAME.log.config" #rc_3pid_validation: # per_second: 0.003 # burst_count: 5 +# +#rc_invites: +# per_room: +# per_second: 0.3 +# burst_count: 10 +# per_user: +# per_second: 0.003 +# burst_count: 5 # Ratelimiting settings for incoming federation # diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 76f382527d..def33a60ad 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -107,6 +107,15 @@ class RatelimitConfig(Config): defaults={"per_second": 0.003, "burst_count": 5}, ) + self.rc_invites_per_room = RateLimitConfig( + config.get("rc_invites", {}).get("per_room", {}), + defaults={"per_second": 0.3, "burst_count": 10}, + ) + self.rc_invites_per_user = RateLimitConfig( + config.get("rc_invites", {}).get("per_user", {}), + defaults={"per_second": 0.003, "burst_count": 5}, + ) + def generate_config_section(self, **kwargs): return """\ ## Ratelimiting ## @@ -137,6 +146,8 @@ class RatelimitConfig(Config): # "remote" for when users are trying to join rooms not on the server (which # can be more expensive) # - one for ratelimiting how often a user or IP can attempt to validate a 3PID. + # - two for ratelimiting how often invites can be sent in a room or to a + # specific user. # # The defaults are as shown below. # @@ -174,6 +185,14 @@ class RatelimitConfig(Config): #rc_3pid_validation: # per_second: 0.003 # burst_count: 5 + # + #rc_invites: + # per_room: + # per_second: 0.3 + # burst_count: 10 + # per_user: + # per_second: 0.003 + # burst_count: 5 # Ratelimiting settings for incoming federation # diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index d330ae5dbc..40e1451201 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -810,7 +810,7 @@ class FederationClient(FederationBase): "User's homeserver does not support this room version", Codes.UNSUPPORTED_ROOM_VERSION, ) - elif e.code == 403: + elif e.code in (403, 429): raise e.to_synapse_error() else: raise diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b6dc7f99b6..dbdfd56ff5 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1617,6 +1617,10 @@ class FederationHandler(BaseHandler): if event.state_key == self._server_notices_mxid: raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") + # We retrieve the room member handler here as to not cause a cyclic dependency + member_handler = self.hs.get_room_member_handler() + member_handler.ratelimit_invite(event.room_id, event.state_key) + # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a # join dance). diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ee27d99135..07b2187eb1 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -126,6 +126,10 @@ class RoomCreationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() + self._invite_burst_count = ( + hs.config.ratelimiting.rc_invites_per_room.burst_count + ) + async def upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion ) -> str: @@ -662,6 +666,9 @@ class RoomCreationHandler(BaseHandler): invite_3pid_list = [] invite_list = [] + if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count: + raise SynapseError(400, "Cannot invite so many users at once") + await self.event_creation_handler.assert_accepted_privacy_policy(requester) power_level_content_override = config.get("power_level_content_override") diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index e001e418f9..d335da6f19 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -85,6 +85,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, ) + self._invites_per_room_limiter = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, + burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, + ) + self._invites_per_user_limiter = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, + burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, + ) + # This is only used to get at ratelimit function, and # maybe_kick_guest_users. It's fine there are multiple of these as # it doesn't store state. @@ -144,6 +155,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): """ raise NotImplementedError() + def ratelimit_invite(self, room_id: str, invitee_user_id: str): + """Ratelimit invites by room and by target user. + """ + self._invites_per_room_limiter.ratelimit(room_id) + self._invites_per_user_limiter.ratelimit(invitee_user_id) + async def _local_membership_update( self, requester: Requester, @@ -387,8 +404,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise SynapseError(403, "This room has been blocked on this server") if effective_membership_state == Membership.INVITE: + target_id = target.to_string() + if ratelimit: + self.ratelimit_invite(room_id, target_id) + # block any attempts to invite the server notices mxid - if target.to_string() == self._server_notices_mxid: + if target_id == self._server_notices_mxid: raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") block_invite = False @@ -412,7 +433,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): block_invite = True if not await self.spam_checker.user_may_invite( - requester.user.to_string(), target.to_string(), room_id + requester.user.to_string(), target_id, room_id ): logger.info("Blocking invite due to spam checker") block_invite = True diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 0b24b89a2e..74503112f5 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -16,7 +16,7 @@ import logging from unittest import TestCase from synapse.api.constants import EventTypes -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase from synapse.federation.federation_base import event_from_pdu_json @@ -191,6 +191,97 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(sg, sg2) + @unittest.override_config( + {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invite_by_room_ratelimit(self): + """Tests that invites from federation in a room are actually rate-limited. + """ + other_server = "otherserver" + other_user = "@otheruser:" + other_server + + # create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) + + def create_invite_for(local_user): + return event_from_pdu_json( + { + "type": EventTypes.Member, + "content": {"membership": "invite"}, + "room_id": room_id, + "sender": other_user, + "state_key": local_user, + "depth": 32, + "prev_events": [], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + + for i in range(3): + self.get_success( + self.handler.on_invite_request( + other_server, + create_invite_for("@user-%d:test" % (i,)), + room_version, + ) + ) + + self.get_failure( + self.handler.on_invite_request( + other_server, create_invite_for("@user-4:test"), room_version, + ), + exc=LimitExceededError, + ) + + @unittest.override_config( + {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invite_by_user_ratelimit(self): + """Tests that invites from federation to a particular user are + actually rate-limited. + """ + other_server = "otherserver" + other_user = "@otheruser:" + other_server + + # create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + def create_invite(): + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) + return event_from_pdu_json( + { + "type": EventTypes.Member, + "content": {"membership": "invite"}, + "room_id": room_id, + "sender": other_user, + "state_key": "@user:test", + "depth": 32, + "prev_events": [], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + + for i in range(3): + event = create_invite() + self.get_success( + self.handler.on_invite_request(other_server, event, event.room_version,) + ) + + event = create_invite() + self.get_failure( + self.handler.on_invite_request(other_server, event, event.room_version,), + exc=LimitExceededError, + ) + def _build_and_send_join_event(self, other_server, other_user, room_id): join_event = self.get_success( self.handler.on_make_join_request(other_server, room_id, other_user) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index d4e3165436..2548b3a80c 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -616,6 +616,41 @@ class RoomMemberStateTestCase(RoomBase): self.assertEquals(json.loads(content), channel.json_body) +class RoomInviteRatelimitTestCase(RoomBase): + user_id = "@sid1:red" + + servlets = [ + admin.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + @unittest.override_config( + {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invites_by_rooms_ratelimit(self): + """Tests that invites in a room are actually rate-limited.""" + room_id = self.helper.create_room_as(self.user_id) + + for i in range(3): + self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,)) + + self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429) + + @unittest.override_config( + {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invites_by_users_ratelimit(self): + """Tests that invites to a specific user are actually rate-limited.""" + + for i in range(3): + room_id = self.helper.create_room_as(self.user_id) + self.helper.invite(room_id, self.user_id, "@other-users:red") + + room_id = self.helper.create_room_as(self.user_id) + self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429) + + class RoomJoinRatelimitTestCase(RoomBase): user_id = "@sid1:red" -- cgit 1.5.1 From f78d07bf005f7212bcc74256721677a3b255ea0e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 13:15:51 +0000 Subject: Split out a separate endpoint to complete SSO registration (#9262) There are going to be a couple of paths to get to the final step of SSO reg, and I want the URL in the browser to consistent. So, let's move the final step onto a separate path, which we redirect to. --- changelog.d/9262.feature | 1 + synapse/app/homeserver.py | 2 + synapse/handlers/sso.py | 81 ++++++++++++++++++++++------ synapse/http/server.py | 7 +++ synapse/rest/synapse/client/pick_username.py | 16 +++--- synapse/rest/synapse/client/sso_register.py | 50 +++++++++++++++++ tests/rest/client/v1/test_login.py | 14 ++++- 7 files changed, 145 insertions(+), 26 deletions(-) create mode 100644 changelog.d/9262.feature create mode 100644 synapse/rest/synapse/client/sso_register.py (limited to 'tests/rest') diff --git a/changelog.d/9262.feature b/changelog.d/9262.feature new file mode 100644 index 0000000000..c21b197ca1 --- /dev/null +++ b/changelog.d/9262.feature @@ -0,0 +1 @@ +Improve the user experience of setting up an account via single-sign on. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 57a2f5237c..86d6f73674 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -62,6 +62,7 @@ from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.synapse.client.pick_idp import PickIdpResource from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.rest.synapse.client.sso_register import SsoRegisterResource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer from synapse.storage import DataStore @@ -192,6 +193,7 @@ class SynapseHomeServer(HomeServer): "/_synapse/admin": AdminRestResource(self), "/_synapse/client/pick_username": pick_username_resource(self), "/_synapse/client/pick_idp": PickIdpResource(self), + "/_synapse/client/sso_register": SsoRegisterResource(self), } ) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 3308b037d2..50c5ae142a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -21,12 +21,13 @@ import attr from typing_extensions import NoReturn, Protocol from twisted.web.http import Request +from twisted.web.iweb import IRequest from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent -from synapse.http.server import respond_with_html +from synapse.http.server import respond_with_html, respond_with_redirect from synapse.http.site import SynapseRequest from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters from synapse.util.async_helpers import Linearizer @@ -141,6 +142,9 @@ class UsernameMappingSession: # expiry time for the session, in milliseconds expiry_time_ms = attr.ib(type=int) + # choices made by the user + chosen_localpart = attr.ib(type=Optional[str], default=None) + # the HTTP cookie used to track the mapping session id USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session" @@ -647,6 +651,25 @@ class SsoHandler: ) respond_with_html(request, 200, html) + def get_mapping_session(self, session_id: str) -> UsernameMappingSession: + """Look up the given username mapping session + + If it is not found, raises a SynapseError with an http code of 400 + + Args: + session_id: session to look up + Returns: + active mapping session + Raises: + SynapseError if the session is not found/has expired + """ + self._expire_old_sessions() + session = self._username_mapping_sessions.get(session_id) + if session: + return session + logger.info("Couldn't find session id %s", session_id) + raise SynapseError(400, "unknown session") + async def check_username_availability( self, localpart: str, session_id: str, ) -> bool: @@ -663,12 +686,7 @@ class SsoHandler: # make sure that there is a valid mapping session, to stop people dictionary- # scanning for accounts - - self._expire_old_sessions() - session = self._username_mapping_sessions.get(session_id) - if not session: - logger.info("Couldn't find session id %s", session_id) - raise SynapseError(400, "unknown session") + self.get_mapping_session(session_id) logger.info( "[session %s] Checking for availability of username %s", @@ -696,16 +714,33 @@ class SsoHandler: localpart: localpart requested by the user session_id: ID of the username mapping session, extracted from a cookie """ - self._expire_old_sessions() - session = self._username_mapping_sessions.get(session_id) - if not session: - logger.info("Couldn't find session id %s", session_id) - raise SynapseError(400, "unknown session") + session = self.get_mapping_session(session_id) + + # update the session with the user's choices + session.chosen_localpart = localpart + + # we're done; now we can register the user + respond_with_redirect(request, b"/_synapse/client/sso_register") + + async def register_sso_user(self, request: Request, session_id: str) -> None: + """Called once we have all the info we need to register a new user. - logger.info("[session %s] Registering localpart %s", session_id, localpart) + Does so and serves an HTTP response + + Args: + request: HTTP request + session_id: ID of the username mapping session, extracted from a cookie + """ + session = self.get_mapping_session(session_id) + + logger.info( + "[session %s] Registering localpart %s", + session_id, + session.chosen_localpart, + ) attributes = UserAttributes( - localpart=localpart, + localpart=session.chosen_localpart, display_name=session.display_name, emails=session.emails, ) @@ -720,7 +755,12 @@ class SsoHandler: request.getClientIP(), ) - logger.info("[session %s] Registered userid %s", session_id, user_id) + logger.info( + "[session %s] Registered userid %s with attributes %s", + session_id, + user_id, + attributes, + ) # delete the mapping session and the cookie del self._username_mapping_sessions[session_id] @@ -751,3 +791,14 @@ class SsoHandler: for session_id in to_expire: logger.info("Expiring mapping session %s", session_id) del self._username_mapping_sessions[session_id] + + +def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: + """Extract the session ID from the cookie + + Raises a SynapseError if the cookie isn't found + """ + session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) + if not session_id: + raise SynapseError(code=400, msg="missing session_id") + return session_id.decode("ascii", errors="replace") diff --git a/synapse/http/server.py b/synapse/http/server.py index d69d579b3a..8249732b27 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -761,6 +761,13 @@ def set_clickjacking_protection_headers(request: Request): request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';") +def respond_with_redirect(request: Request, url: bytes) -> None: + """Write a 302 response to the request, if it is still alive.""" + logger.debug("Redirect to %s", url.decode("utf-8")) + request.redirect(url) + finish_request(request) + + def finish_request(request: Request): """ Finish writing the response to the request. diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index d3b6803e65..1bc737bad0 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import TYPE_CHECKING import pkg_resources @@ -20,8 +21,7 @@ from twisted.web.http import Request from twisted.web.resource import Resource from twisted.web.static import File -from synapse.api.errors import SynapseError -from synapse.handlers.sso import USERNAME_MAPPING_SESSION_COOKIE_NAME +from synapse.handlers.sso import get_username_mapping_session_cookie_from_request from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest @@ -61,12 +61,10 @@ class AvailabilityCheckResource(DirectServeJsonResource): async def _async_render_GET(self, request: Request): localpart = parse_string(request, "username", required=True) - session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) - if not session_id: - raise SynapseError(code=400, msg="missing session_id") + session_id = get_username_mapping_session_cookie_from_request(request) is_available = await self._sso_handler.check_username_availability( - localpart, session_id.decode("ascii", errors="replace") + localpart, session_id ) return 200, {"available": is_available} @@ -79,10 +77,8 @@ class SubmitResource(DirectServeHtmlResource): async def _async_render_POST(self, request: SynapseRequest): localpart = parse_string(request, "username", required=True) - session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) - if not session_id: - raise SynapseError(code=400, msg="missing session_id") + session_id = get_username_mapping_session_cookie_from_request(request) await self._sso_handler.handle_submit_username_request( - request, localpart, session_id.decode("ascii", errors="replace") + request, localpart, session_id ) diff --git a/synapse/rest/synapse/client/sso_register.py b/synapse/rest/synapse/client/sso_register.py new file mode 100644 index 0000000000..dfefeb7796 --- /dev/null +++ b/synapse/rest/synapse/client/sso_register.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request + +from synapse.api.errors import SynapseError +from synapse.handlers.sso import get_username_mapping_session_cookie_from_request +from synapse.http.server import DirectServeHtmlResource + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class SsoRegisterResource(DirectServeHtmlResource): + """A resource which completes SSO registration + + This resource gets mounted at /_synapse/client/sso_register, and is shown + after we collect username and/or consent for a new SSO user. It (finally) registers + the user, and confirms redirect to the client + """ + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + + async def _async_render_GET(self, request: Request) -> None: + try: + session_id = get_username_mapping_session_cookie_from_request(request) + except SynapseError as e: + logger.warning("Error fetching session cookie: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + await self._sso_handler.register_sso_user(request, session_id) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index e2bb945453..f01215ed1c 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from synapse.rest.synapse.client.pick_idp import PickIdpResource from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.rest.synapse.client.sso_register import SsoRegisterResource from synapse.types import create_requester from tests import unittest @@ -1215,6 +1216,7 @@ class UsernamePickerTestCase(HomeserverTestCase): d = super().create_resource_dict() d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) + d["/_synapse/client/sso_register"] = SsoRegisterResource(self.hs) d["/_synapse/oidc"] = OIDCResource(self.hs) return d @@ -1253,7 +1255,7 @@ class UsernamePickerTestCase(HomeserverTestCase): self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) # Now, submit a username to the username picker, which should serve a redirect - # back to the client + # to the completion page submit_path = picker_url + "/submit" content = urlencode({b"username": b"bobby"}).encode("utf8") chan = self.make_request( @@ -1270,6 +1272,16 @@ class UsernamePickerTestCase(HomeserverTestCase): ) self.assertEqual(chan.code, 302, chan.result) location_headers = chan.headers.getRawHeaders("Location") + + # send a request to the completion page, which should 302 to the client redirectUrl + chan = self.make_request( + "GET", + path=location_headers[0], + custom_headers=[("Cookie", "username_mapping_session=" + session_id)], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + # ensure that the returned location matches the requested redirect URL path, query = location_headers[0].split("?", 1) self.assertEqual(path, "https://x") -- cgit 1.5.1 From 9c715a5f1981891815c124353ba15cf4d17bf9bb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 15:47:59 +0000 Subject: Fix SSO on workers (#9271) Fixes #8966. * Factor out build_synapse_client_resource_tree Start a function which will mount resources common to all workers. * Move sso init into build_synapse_client_resource_tree ... so that we don't have to do it for each worker * Fix SSO-login-via-a-worker Expose the SSO login endpoints on workers, like the documentation says. * Update workers config for new endpoints Add documentation for endpoints recently added (#8942, #9017, #9262) * remove submit_token from workers endpoints list this *doesn't* work on workers (yet). * changelog * Add a comment about the odd path for SAML2Resource --- changelog.d/9271.bugfix | 1 + docs/workers.md | 18 +++++----- synapse/app/generic_worker.py | 11 +++--- synapse/app/homeserver.py | 18 ++-------- synapse/rest/synapse/client/__init__.py | 49 +++++++++++++++++++++++++- synapse/storage/databases/main/registration.py | 40 ++++++++++----------- tests/rest/client/v1/test_login.py | 15 ++------ tests/rest/client/v2_alpha/test_auth.py | 6 ++-- 8 files changed, 93 insertions(+), 65 deletions(-) create mode 100644 changelog.d/9271.bugfix (limited to 'tests/rest') diff --git a/changelog.d/9271.bugfix b/changelog.d/9271.bugfix new file mode 100644 index 0000000000..ef30c6570f --- /dev/null +++ b/changelog.d/9271.bugfix @@ -0,0 +1 @@ +Fix single-sign-on when the endpoints are routed to synapse workers. diff --git a/docs/workers.md b/docs/workers.md index d01683681f..6b8887de36 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -225,7 +225,6 @@ expressions: ^/_matrix/client/(api/v1|r0|unstable)/joined_groups$ ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$ ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/ - ^/_synapse/client/password_reset/email/submit_token$ # Registration/login requests ^/_matrix/client/(api/v1|r0|unstable)/login$ @@ -256,25 +255,28 @@ Additionally, the following endpoints should be included if Synapse is configure to use SSO (you only need to include the ones for whichever SSO provider you're using): + # for all SSO providers + ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect + ^/_synapse/client/pick_idp$ + ^/_synapse/client/pick_username + ^/_synapse/client/sso_register$ + # OpenID Connect requests. - ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$ ^/_synapse/oidc/callback$ # SAML requests. - ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$ ^/_matrix/saml2/authn_response$ # CAS requests. - ^/_matrix/client/(api/v1|r0|unstable)/login/(cas|sso)/redirect$ ^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$ -Note that a HTTP listener with `client` and `federation` resources must be -configured in the `worker_listeners` option in the worker config. - -Ensure that all SSO logins go to a single process (usually the main process). +Ensure that all SSO logins go to a single process. For multiple workers not handling the SSO endpoints properly, see [#7530](https://github.com/matrix-org/synapse/issues/7530). +Note that a HTTP listener with `client` and `federation` resources must be +configured in the `worker_listeners` option in the worker config. + #### Load balancing It is possible to run multiple instances of this worker app, with incoming requests diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index e60988fa4a..516f2464b4 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -22,6 +22,7 @@ from typing import Dict, Iterable, Optional, Set from typing_extensions import ContextManager from twisted.internet import address +from twisted.web.resource import IResource import synapse import synapse.events @@ -90,9 +91,8 @@ from synapse.replication.tcp.streams import ( ToDeviceStream, ) from synapse.rest.admin import register_servlets_for_media_repo -from synapse.rest.client.v1 import events, room +from synapse.rest.client.v1 import events, login, room from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet -from synapse.rest.client.v1.login import LoginRestServlet from synapse.rest.client.v1.profile import ( ProfileAvatarURLRestServlet, ProfileDisplaynameRestServlet, @@ -127,6 +127,7 @@ from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.server import HomeServer, cache_in_self from synapse.storage.databases.main.censor_events import CensorEventsStore from synapse.storage.databases.main.client_ips import ClientIpWorkerStore @@ -507,7 +508,7 @@ class GenericWorkerServer(HomeServer): site_tag = port # We always include a health resource. - resources = {"/health": HealthResource()} + resources = {"/health": HealthResource()} # type: Dict[str, IResource] for res in listener_config.http_options.resources: for name in res.names: @@ -517,7 +518,7 @@ class GenericWorkerServer(HomeServer): resource = JsonResource(self, canonical_json=False) RegisterRestServlet(self).register(resource) - LoginRestServlet(self).register(resource) + login.register_servlets(self, resource) ThreepidRestServlet(self).register(resource) DevicesRestServlet(self).register(resource) KeyQueryServlet(self).register(resource) @@ -557,6 +558,8 @@ class GenericWorkerServer(HomeServer): groups.register_servlets(self, resource) resources.update({CLIENT_API_PREFIX: resource}) + + resources.update(build_synapse_client_resource_tree(self)) elif name == "federation": resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) elif name == "media": diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 86d6f73674..244657cb88 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -60,9 +60,7 @@ from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource -from synapse.rest.synapse.client.pick_idp import PickIdpResource -from synapse.rest.synapse.client.pick_username import pick_username_resource -from synapse.rest.synapse.client.sso_register import SsoRegisterResource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer from synapse.storage import DataStore @@ -191,22 +189,10 @@ class SynapseHomeServer(HomeServer): "/_matrix/client/versions": client_resource, "/.well-known/matrix/client": WellKnownResource(self), "/_synapse/admin": AdminRestResource(self), - "/_synapse/client/pick_username": pick_username_resource(self), - "/_synapse/client/pick_idp": PickIdpResource(self), - "/_synapse/client/sso_register": SsoRegisterResource(self), + **build_synapse_client_resource_tree(self), } ) - if self.get_config().oidc_enabled: - from synapse.rest.oidc import OIDCResource - - resources["/_synapse/oidc"] = OIDCResource(self) - - if self.get_config().saml2_enabled: - from synapse.rest.saml2 import SAML2Resource - - resources["/_matrix/saml2"] = SAML2Resource(self) - if self.get_config().threepid_behaviour_email == ThreepidBehaviour.LOCAL: from synapse.rest.synapse.client.password_reset import ( PasswordResetSubmitTokenResource, diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index c0b733488b..6acbc03d73 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2020 The Matrix.org Foundation C.I.C. +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,3 +12,50 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from typing import TYPE_CHECKING, Mapping + +from twisted.web.resource import Resource + +from synapse.rest.synapse.client.pick_idp import PickIdpResource +from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.rest.synapse.client.sso_register import SsoRegisterResource + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resource]: + """Builds a resource tree to include synapse-specific client resources + + These are resources which should be loaded on all workers which expose a C-S API: + ie, the main process, and any generic workers so configured. + + Returns: + map from path to Resource. + """ + resources = { + # SSO bits. These are always loaded, whether or not SSO login is actually + # enabled (they just won't work very well if it's not) + "/_synapse/client/pick_idp": PickIdpResource(hs), + "/_synapse/client/pick_username": pick_username_resource(hs), + "/_synapse/client/sso_register": SsoRegisterResource(hs), + } + + # provider-specific SSO bits. Only load these if they are enabled, since they + # rely on optional dependencies. + if hs.config.oidc_enabled: + from synapse.rest.oidc import OIDCResource + + resources["/_synapse/oidc"] = OIDCResource(hs) + + if hs.config.saml2_enabled: + from synapse.rest.saml2 import SAML2Resource + + # This is mounted under '/_matrix' for backwards-compatibility. + resources["/_matrix/saml2"] = SAML2Resource(hs) + + return resources + + +__all__ = ["build_synapse_client_resource_tree"] diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 8d05288ed4..14c0878d81 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -443,6 +443,26 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) + async def record_user_external_id( + self, auth_provider: str, external_id: str, user_id: str + ) -> None: + """Record a mapping from an external user id to a mxid + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ + await self.db_pool.simple_insert( + table="user_external_ids", + values={ + "auth_provider": auth_provider, + "external_id": external_id, + "user_id": user_id, + }, + desc="record_user_external_id", + ) + async def get_user_by_external_id( self, auth_provider: str, external_id: str ) -> Optional[str]: @@ -1371,26 +1391,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - async def record_user_external_id( - self, auth_provider: str, external_id: str, user_id: str - ) -> None: - """Record a mapping from an external user id to a mxid - - Args: - auth_provider: identifier for the remote auth provider - external_id: id on that system - user_id: complete mxid that it is mapped to - """ - await self.db_pool.simple_insert( - table="user_external_ids", - values={ - "auth_provider": auth_provider, - "external_id": external_id, - "user_id": user_id, - }, - desc="record_user_external_id", - ) - async def user_set_password_hash( self, user_id: str, password_hash: Optional[str] ) -> None: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index f01215ed1c..ded22a9767 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -29,9 +29,7 @@ from synapse.appservice import ApplicationService from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet -from synapse.rest.synapse.client.pick_idp import PickIdpResource -from synapse.rest.synapse.client.pick_username import pick_username_resource -from synapse.rest.synapse.client.sso_register import SsoRegisterResource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import create_requester from tests import unittest @@ -424,11 +422,8 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): return config def create_resource_dict(self) -> Dict[str, Resource]: - from synapse.rest.oidc import OIDCResource - d = super().create_resource_dict() - d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs) - d["/_synapse/oidc"] = OIDCResource(self.hs) + d.update(build_synapse_client_resource_tree(self.hs)) return d def test_get_login_flows(self): @@ -1212,12 +1207,8 @@ class UsernamePickerTestCase(HomeserverTestCase): return config def create_resource_dict(self) -> Dict[str, Resource]: - from synapse.rest.oidc import OIDCResource - d = super().create_resource_dict() - d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) - d["/_synapse/client/sso_register"] = SsoRegisterResource(self.hs) - d["/_synapse/oidc"] = OIDCResource(self.hs) + d.update(build_synapse_client_resource_tree(self.hs)) return d def test_username_picker(self): diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index a6488a3d29..3f50c56745 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -22,7 +22,7 @@ from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import auth, devices, register -from synapse.rest.oidc import OIDCResource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import JsonDict, UserID from tests import unittest @@ -173,9 +173,7 @@ class UIAuthTests(unittest.HomeserverTestCase): def create_resource_dict(self): resource_dict = super().create_resource_dict() - if HAS_OIDC: - # mount the OIDC resource at /_synapse/oidc - resource_dict["/_synapse/oidc"] = OIDCResource(self.hs) + resource_dict.update(build_synapse_client_resource_tree(self.hs)) return resource_dict def prepare(self, reactor, clock, hs): -- cgit 1.5.1 From 4167494c90bc0477bdf4855a79e81dc81bba1377 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 15:52:50 +0000 Subject: Replace username picker with a template (#9275) There's some prelimiary work here to pull out the construction of a jinja environment to a separate function. I wanted to load the template at display time rather than load time, so that it's easy to update on the fly. Honestly, I think we should do this with all our templates: the risk of ending up with malformed templates is far outweighed by the improved turnaround time for an admin trying to update them. --- changelog.d/9275.feature | 1 + docs/sample_config.yaml | 32 +++++- synapse/config/_base.py | 39 +------ synapse/config/oidc_config.py | 3 +- synapse/config/sso.py | 33 +++++- synapse/handlers/sso.py | 2 +- .../res/templates/sso_auth_account_details.html | 115 +++++++++++++++++++++ synapse/res/templates/sso_auth_account_details.js | 76 ++++++++++++++ synapse/res/username_picker/index.html | 19 ---- synapse/res/username_picker/script.js | 95 ----------------- synapse/res/username_picker/style.css | 27 ----- synapse/rest/consent/consent_resource.py | 1 + synapse/rest/synapse/client/pick_username.py | 79 ++++++++++---- synapse/util/templates.py | 106 +++++++++++++++++++ tests/rest/client/v1/test_login.py | 5 +- 15 files changed, 429 insertions(+), 204 deletions(-) create mode 100644 changelog.d/9275.feature create mode 100644 synapse/res/templates/sso_auth_account_details.html create mode 100644 synapse/res/templates/sso_auth_account_details.js delete mode 100644 synapse/res/username_picker/index.html delete mode 100644 synapse/res/username_picker/script.js delete mode 100644 synapse/res/username_picker/style.css create mode 100644 synapse/util/templates.py (limited to 'tests/rest') diff --git a/changelog.d/9275.feature b/changelog.d/9275.feature new file mode 100644 index 0000000000..c21b197ca1 --- /dev/null +++ b/changelog.d/9275.feature @@ -0,0 +1 @@ +Improve the user experience of setting up an account via single-sign on. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 05506a7787..a6fbcc6080 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1801,7 +1801,8 @@ saml2_config: # # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their -# own username. +# own username (see 'sso_auth_account_details.html' in the 'sso' +# section of this file). # # display_name_template: Jinja2 template for the display name to set # on first login. If unset, no displayname will be set. @@ -1968,6 +1969,35 @@ sso: # # * idp: the 'idp_id' of the chosen IDP. # + # * HTML page to prompt new users to enter a userid and confirm other + # details: 'sso_auth_account_details.html'. This is only shown if the + # SSO implementation (with any user_mapping_provider) does not return + # a localpart. + # + # When rendering, this template is given the following variables: + # + # * server_name: the homeserver's name. + # + # * idp: details of the SSO Identity Provider that the user logged in + # with: an object with the following attributes: + # + # * idp_id: unique identifier for the IdP + # * idp_name: user-facing name for the IdP + # * idp_icon: if specified in the IdP config, an MXC URI for an icon + # for the IdP + # * idp_brand: if specified in the IdP config, a textual identifier + # for the brand of the IdP + # + # * user_attributes: an object containing details about the user that + # we received from the IdP. May have the following attributes: + # + # * display_name: the user's display_name + # * emails: a list of email addresses + # + # The template should render a form which submits the following fields: + # + # * username: the localpart of the user's chosen user id + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 94144efc87..35e5594b73 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -18,18 +18,18 @@ import argparse import errno import os -import time -import urllib.parse from collections import OrderedDict from hashlib import sha256 from textwrap import dedent -from typing import Any, Callable, Iterable, List, MutableMapping, Optional +from typing import Any, Iterable, List, MutableMapping, Optional import attr import jinja2 import pkg_resources import yaml +from synapse.util.templates import _create_mxc_to_http_filter, _format_ts_filter + class ConfigError(Exception): """Represents a problem parsing the configuration @@ -248,6 +248,7 @@ class Config: # Search the custom template directory as well search_directories.insert(0, custom_template_directory) + # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(search_directories) env = jinja2.Environment(loader=loader, autoescape=autoescape) @@ -267,38 +268,6 @@ class Config: return templates -def _format_ts_filter(value: int, format: str): - return time.strftime(format, time.localtime(value / 1000)) - - -def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: - """Create and return a jinja2 filter that converts MXC urls to HTTP - - Args: - public_baseurl: The public, accessible base URL of the homeserver - """ - - def mxc_to_http_filter(value, width, height, resize_method="crop"): - if value[0:6] != "mxc://": - return "" - - server_and_media_id = value[6:] - fragment = None - if "#" in server_and_media_id: - server_and_media_id, fragment = server_and_media_id.split("#", 1) - fragment = "#" + fragment - - params = {"width": width, "height": height, "method": resize_method} - return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( - public_baseurl, - server_and_media_id, - urllib.parse.urlencode(params), - fragment or "", - ) - - return mxc_to_http_filter - - class RootConfig: """ Holder of an application's configuration. diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index f31511e039..784b416f95 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -152,7 +152,8 @@ class OIDCConfig(Config): # # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their - # own username. + # own username (see 'sso_auth_account_details.html' in the 'sso' + # section of this file). # # display_name_template: Jinja2 template for the display name to set # on first login. If unset, no displayname will be set. diff --git a/synapse/config/sso.py b/synapse/config/sso.py index a470112ed4..e308fc9333 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -27,7 +27,7 @@ class SSOConfig(Config): sso_config = config.get("sso") or {} # type: Dict[str, Any] # The sso-specific template_dir - template_dir = sso_config.get("template_dir") + self.sso_template_dir = sso_config.get("template_dir") # Read templates from disk ( @@ -48,7 +48,7 @@ class SSOConfig(Config): "sso_auth_success.html", "sso_auth_bad_user.html", ], - template_dir, + self.sso_template_dir, ) # These templates have no placeholders, so render them here @@ -124,6 +124,35 @@ class SSOConfig(Config): # # * idp: the 'idp_id' of the chosen IDP. # + # * HTML page to prompt new users to enter a userid and confirm other + # details: 'sso_auth_account_details.html'. This is only shown if the + # SSO implementation (with any user_mapping_provider) does not return + # a localpart. + # + # When rendering, this template is given the following variables: + # + # * server_name: the homeserver's name. + # + # * idp: details of the SSO Identity Provider that the user logged in + # with: an object with the following attributes: + # + # * idp_id: unique identifier for the IdP + # * idp_name: user-facing name for the IdP + # * idp_icon: if specified in the IdP config, an MXC URI for an icon + # for the IdP + # * idp_brand: if specified in the IdP config, a textual identifier + # for the brand of the IdP + # + # * user_attributes: an object containing details about the user that + # we received from the IdP. May have the following attributes: + # + # * display_name: the user's display_name + # * emails: a list of email addresses + # + # The template should render a form which submits the following fields: + # + # * username: the localpart of the user's chosen user id + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index ceaeb5a376..ff4750999a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -530,7 +530,7 @@ class SsoHandler: logger.info("Recorded registration session id %s", session_id) # Set the cookie and redirect to the username picker - e = RedirectException(b"/_synapse/client/pick_username") + e = RedirectException(b"/_synapse/client/pick_username/account_details") e.cookies.append( b"%s=%s; path=/" % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii")) diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html new file mode 100644 index 0000000000..f22b09aec1 --- /dev/null +++ b/synapse/res/templates/sso_auth_account_details.html @@ -0,0 +1,115 @@ + + + + Synapse Login + + + + + +
+

Your account is nearly ready

+

Check your details before creating an account on {{ server_name }}

+
+
+
+
+ +
@
+ +
:{{ server_name }}
+
+ + {% if user_attributes %} +
+

Information from {{ idp.idp_name }}

+ {% if user_attributes.avatar_url %} +
+ +
+ {% endif %} + {% if user_attributes.display_name %} +
+

{{ user_attributes.display_name }}

+
+ {% endif %} + {% for email in user_attributes.emails %} +
+

{{ email }}

+
+ {% endfor %} +
+ {% endif %} +
+
+ + + diff --git a/synapse/res/templates/sso_auth_account_details.js b/synapse/res/templates/sso_auth_account_details.js new file mode 100644 index 0000000000..deef419bb6 --- /dev/null +++ b/synapse/res/templates/sso_auth_account_details.js @@ -0,0 +1,76 @@ +const usernameField = document.getElementById("field-username"); + +function throttle(fn, wait) { + let timeout; + return function() { + const args = Array.from(arguments); + if (timeout) { + clearTimeout(timeout); + } + timeout = setTimeout(fn.bind.apply(fn, [null].concat(args)), wait); + } +} + +function checkUsernameAvailable(username) { + let check_uri = 'check?username=' + encodeURIComponent(username); + return fetch(check_uri, { + // include the cookie + "credentials": "same-origin", + }).then((response) => { + if(!response.ok) { + // for non-200 responses, raise the body of the response as an exception + return response.text().then((text) => { throw new Error(text); }); + } else { + return response.json(); + } + }).then((json) => { + if(json.error) { + return {message: json.error}; + } else if(json.available) { + return {available: true}; + } else { + return {message: username + " is not available, please choose another."}; + } + }); +} + +function validateUsername(username) { + usernameField.setCustomValidity(""); + if (usernameField.validity.valueMissing) { + usernameField.setCustomValidity("Please provide a username"); + return; + } + if (usernameField.validity.patternMismatch) { + usernameField.setCustomValidity("Invalid username, please only use " + allowedCharactersString); + return; + } + usernameField.setCustomValidity("Checking if username is available …"); + throttledCheckUsernameAvailable(username); +} + +const throttledCheckUsernameAvailable = throttle(function(username) { + const handleError = function(err) { + // don't prevent form submission on error + usernameField.setCustomValidity(""); + console.log(err.message); + }; + try { + checkUsernameAvailable(username).then(function(result) { + if (!result.available) { + usernameField.setCustomValidity(result.message); + usernameField.reportValidity(); + } else { + usernameField.setCustomValidity(""); + } + }, handleError); + } catch (err) { + handleError(err); + } +}, 500); + +usernameField.addEventListener("input", function(evt) { + validateUsername(usernameField.value); +}); +usernameField.addEventListener("change", function(evt) { + validateUsername(usernameField.value); +}); diff --git a/synapse/res/username_picker/index.html b/synapse/res/username_picker/index.html deleted file mode 100644 index 37ea8bb6d8..0000000000 --- a/synapse/res/username_picker/index.html +++ /dev/null @@ -1,19 +0,0 @@ - - - - Synapse Login - - - -
-
- - - -
- - - -
- - diff --git a/synapse/res/username_picker/script.js b/synapse/res/username_picker/script.js deleted file mode 100644 index 416a7c6f41..0000000000 --- a/synapse/res/username_picker/script.js +++ /dev/null @@ -1,95 +0,0 @@ -let inputField = document.getElementById("field-username"); -let inputForm = document.getElementById("form"); -let submitButton = document.getElementById("button-submit"); -let message = document.getElementById("message"); - -// Submit username and receive response -function showMessage(messageText) { - // Unhide the message text - message.classList.remove("hidden"); - - message.textContent = messageText; -}; - -function doSubmit() { - showMessage("Success. Please wait a moment for your browser to redirect."); - - // remove the event handler before re-submitting the form. - delete inputForm.onsubmit; - inputForm.submit(); -} - -function onResponse(response) { - // Display message - showMessage(response); - - // Enable submit button and input field - submitButton.classList.remove('button--disabled'); - submitButton.value = "Submit"; -}; - -let allowedUsernameCharacters = RegExp("[^a-z0-9\\.\\_\\=\\-\\/]"); -function usernameIsValid(username) { - return !allowedUsernameCharacters.test(username); -} -let allowedCharactersString = "lowercase letters, digits, ., _, -, /, ="; - -function buildQueryString(params) { - return Object.keys(params) - .map(k => encodeURIComponent(k) + '=' + encodeURIComponent(params[k])) - .join('&'); -} - -function submitUsername(username) { - if(username.length == 0) { - onResponse("Please enter a username."); - return; - } - if(!usernameIsValid(username)) { - onResponse("Invalid username. Only the following characters are allowed: " + allowedCharactersString); - return; - } - - // if this browser doesn't support fetch, skip the availability check. - if(!window.fetch) { - doSubmit(); - return; - } - - let check_uri = 'check?' + buildQueryString({"username": username}); - fetch(check_uri, { - // include the cookie - "credentials": "same-origin", - }).then((response) => { - if(!response.ok) { - // for non-200 responses, raise the body of the response as an exception - return response.text().then((text) => { throw text; }); - } else { - return response.json(); - } - }).then((json) => { - if(json.error) { - throw json.error; - } else if(json.available) { - doSubmit(); - } else { - onResponse("This username is not available, please choose another."); - } - }).catch((err) => { - onResponse("Error checking username availability: " + err); - }); -} - -function clickSubmit() { - event.preventDefault(); - if(submitButton.classList.contains('button--disabled')) { return; } - - // Disable submit button and input field - submitButton.classList.add('button--disabled'); - - // Submit username - submitButton.value = "Checking..."; - submitUsername(inputField.value); -}; - -inputForm.onsubmit = clickSubmit; diff --git a/synapse/res/username_picker/style.css b/synapse/res/username_picker/style.css deleted file mode 100644 index 745bd4c684..0000000000 --- a/synapse/res/username_picker/style.css +++ /dev/null @@ -1,27 +0,0 @@ -input[type="text"] { - font-size: 100%; - background-color: #ededf0; - border: 1px solid #fff; - border-radius: .2em; - padding: .5em .9em; - display: block; - width: 26em; -} - -.button--disabled { - border-color: #fff; - background-color: transparent; - color: #000; - text-transform: none; -} - -.hidden { - display: none; -} - -.tooltip { - background-color: #f9f9fa; - padding: 1em; - margin: 1em 0; -} - diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index b3e4d5612e..8b9ef26cf2 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -100,6 +100,7 @@ class ConsentResource(DirectServeHtmlResource): consent_template_directory = hs.config.user_consent_template_dir + # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(consent_template_directory) self._jinja_env = jinja2.Environment( loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"]) diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index 1bc737bad0..27540d3bbe 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -13,41 +13,41 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from typing import TYPE_CHECKING -import pkg_resources - from twisted.web.http import Request from twisted.web.resource import Resource -from twisted.web.static import File +from synapse.api.errors import SynapseError from synapse.handlers.sso import get_username_mapping_session_cookie_from_request -from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource +from synapse.http.server import ( + DirectServeHtmlResource, + DirectServeJsonResource, + respond_with_html, +) from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest +from synapse.util.templates import build_jinja_env if TYPE_CHECKING: from synapse.server import HomeServer +logger = logging.getLogger(__name__) + def pick_username_resource(hs: "HomeServer") -> Resource: """Factory method to generate the username picker resource. - This resource gets mounted under /_synapse/client/pick_username. The top-level - resource is just a File resource which serves up the static files in the resources - "res" directory, but it has a couple of children: + This resource gets mounted under /_synapse/client/pick_username and has two + children: - * "submit", which does the mechanics of registering the new user, and redirects the - browser back to the client URL - - * "check": checks if a userid is free. + * "account_details": renders the form and handles the POSTed response + * "check": a JSON endpoint which checks if a userid is free. """ - # XXX should we make this path customisable so that admins can restyle it? - base_path = pkg_resources.resource_filename("synapse", "res/username_picker") - - res = File(base_path) - res.putChild(b"submit", SubmitResource(hs)) + res = Resource() + res.putChild(b"account_details", AccountDetailsResource(hs)) res.putChild(b"check", AvailabilityCheckResource(hs)) return res @@ -69,15 +69,54 @@ class AvailabilityCheckResource(DirectServeJsonResource): return 200, {"available": is_available} -class SubmitResource(DirectServeHtmlResource): +class AccountDetailsResource(DirectServeHtmlResource): def __init__(self, hs: "HomeServer"): super().__init__() self._sso_handler = hs.get_sso_handler() - async def _async_render_POST(self, request: SynapseRequest): - localpart = parse_string(request, "username", required=True) + def template_search_dirs(): + if hs.config.sso.sso_template_dir: + yield hs.config.sso.sso_template_dir + yield hs.config.sso.default_template_dir + + self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) + + async def _async_render_GET(self, request: Request) -> None: + try: + session_id = get_username_mapping_session_cookie_from_request(request) + session = self._sso_handler.get_mapping_session(session_id) + except SynapseError as e: + logger.warning("Error fetching session: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + idp_id = session.auth_provider_id + template_params = { + "idp": self._sso_handler.get_identity_providers()[idp_id], + "user_attributes": { + "display_name": session.display_name, + "emails": session.emails, + }, + } + + template = self._jinja_env.get_template("sso_auth_account_details.html") + html = template.render(template_params) + respond_with_html(request, 200, html) - session_id = get_username_mapping_session_cookie_from_request(request) + async def _async_render_POST(self, request: SynapseRequest): + try: + session_id = get_username_mapping_session_cookie_from_request(request) + except SynapseError as e: + logger.warning("Error fetching session cookie: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + try: + localpart = parse_string(request, "username", required=True) + except SynapseError as e: + logger.warning("[session %s] bad param: %s", session_id, e) + self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code) + return await self._sso_handler.handle_submit_username_request( request, localpart, session_id diff --git a/synapse/util/templates.py b/synapse/util/templates.py new file mode 100644 index 0000000000..7e5109d206 --- /dev/null +++ b/synapse/util/templates.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for dealing with jinja2 templates""" + +import time +import urllib.parse +from typing import TYPE_CHECKING, Callable, Iterable, Union + +import jinja2 + +if TYPE_CHECKING: + from synapse.config.homeserver import HomeServerConfig + + +def build_jinja_env( + template_search_directories: Iterable[str], + config: "HomeServerConfig", + autoescape: Union[bool, Callable[[str], bool], None] = None, +) -> jinja2.Environment: + """Set up a Jinja2 environment to load templates from the given search path + + The returned environment defines the following filters: + - format_ts: formats timestamps as strings in the server's local timezone + (XXX: why is that useful??) + - mxc_to_http: converts mxc: uris to http URIs. Args are: + (uri, width, height, resize_method="crop") + + and the following global variables: + - server_name: matrix server name + + Args: + template_search_directories: directories to search for templates + + config: homeserver config, for things like `server_name` and `public_baseurl` + + autoescape: whether template variables should be autoescaped. bool, or + a function mapping from template name to bool. Defaults to escaping templates + whose names end in .html, .xml or .htm. + + Returns: + jinja environment + """ + + if autoescape is None: + autoescape = jinja2.select_autoescape() + + loader = jinja2.FileSystemLoader(template_search_directories) + env = jinja2.Environment(loader=loader, autoescape=autoescape) + + # Update the environment with our custom filters + env.filters.update( + { + "format_ts": _format_ts_filter, + "mxc_to_http": _create_mxc_to_http_filter(config.public_baseurl), + } + ) + + # common variables for all templates + env.globals.update({"server_name": config.server_name}) + + return env + + +def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: + """Create and return a jinja2 filter that converts MXC urls to HTTP + + Args: + public_baseurl: The public, accessible base URL of the homeserver + """ + + def mxc_to_http_filter(value, width, height, resize_method="crop"): + if value[0:6] != "mxc://": + return "" + + server_and_media_id = value[6:] + fragment = None + if "#" in server_and_media_id: + server_and_media_id, fragment = server_and_media_id.split("#", 1) + fragment = "#" + fragment + + params = {"width": width, "height": height, "method": resize_method} + return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( + public_baseurl, + server_and_media_id, + urllib.parse.urlencode(params), + fragment or "", + ) + + return mxc_to_http_filter + + +def _format_ts_filter(value: int, format: str): + return time.strftime(format, time.localtime(value / 1000)) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index ded22a9767..66dfdaffbc 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -1222,7 +1222,7 @@ class UsernamePickerTestCase(HomeserverTestCase): # that should redirect to the username picker self.assertEqual(channel.code, 302, channel.result) picker_url = channel.headers.getRawHeaders("Location")[0] - self.assertEqual(picker_url, "/_synapse/client/pick_username") + self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") # ... with a username_mapping_session cookie cookies = {} # type: Dict[str,str] @@ -1247,11 +1247,10 @@ class UsernamePickerTestCase(HomeserverTestCase): # Now, submit a username to the username picker, which should serve a redirect # to the completion page - submit_path = picker_url + "/submit" content = urlencode({b"username": b"bobby"}).encode("utf8") chan = self.make_request( "POST", - path=submit_path, + path=picker_url, content=content, content_is_form=True, custom_headers=[ -- cgit 1.5.1 From b60bb28bbc3d916586a913970298baba483efc1f Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Tue, 2 Feb 2021 04:16:29 -0700 Subject: Add an admin API to get the current room state (#9168) This could arguably replace the existing admin API for `/members`, however that is out of scope of this change. This sort of endpoint is ideal for moderation use cases as well as other applications, such as needing to retrieve various bits of information about a room to perform a task (like syncing power levels between two places). This endpoint exposes nothing more than an admin would be able to access with a `select *` query on their database. --- changelog.d/9168.feature | 1 + docs/admin_api/rooms.md | 30 ++++++++++++++++++++++++++++++ synapse/handlers/message.py | 2 +- synapse/rest/admin/__init__.py | 2 ++ synapse/rest/admin/rooms.py | 39 +++++++++++++++++++++++++++++++++++++++ tests/rest/admin/test_room.py | 15 +++++++++++++++ 6 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9168.feature (limited to 'tests/rest') diff --git a/changelog.d/9168.feature b/changelog.d/9168.feature new file mode 100644 index 0000000000..8be1950eee --- /dev/null +++ b/changelog.d/9168.feature @@ -0,0 +1 @@ +Add an admin API for retrieving the current room state of a room. \ No newline at end of file diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index f34cec1ff7..3832b36407 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -368,6 +368,36 @@ Response: } ``` +# Room State API + +The Room State admin API allows server admins to get a list of all state events in a room. + +The response includes the following fields: + +* `state` - The current state of the room at the time of request. + +## Usage + +A standard request: + +``` +GET /_synapse/admin/v1/rooms//state + +{} +``` + +Response: + +```json +{ + "state": [ + {"type": "m.room.create", "state_key": "", "etc": true}, + {"type": "m.room.power_levels", "state_key": "", "etc": true}, + {"type": "m.room.name", "state_key": "", "etc": true} + ] +} +``` + # Delete Room API The Delete Room admin API allows server admins to remove rooms from server diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e2a7d567fa..a15336bf00 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -174,7 +174,7 @@ class MessageHandler: raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = await filter_events_for_client( - self.storage, user_id, last_events, filter_send_to_client=False + self.storage, user_id, last_events, filter_send_to_client=False, ) event = last_events[0] diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 57e0a10837..f5c5d164f9 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -44,6 +44,7 @@ from synapse.rest.admin.rooms import ( MakeRoomAdminRestServlet, RoomMembersRestServlet, RoomRestServlet, + RoomStateRestServlet, ShutdownRoomRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet @@ -213,6 +214,7 @@ def register_servlets(hs, http_server): """ register_servlets_for_client_rest_resource(hs, http_server) ListRoomRestServlet(hs).register(http_server) + RoomStateRestServlet(hs).register(http_server) RoomRestServlet(hs).register(http_server) RoomMembersRestServlet(hs).register(http_server) DeleteRoomRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f14915d47e..3e57e6a4d0 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -292,6 +292,45 @@ class RoomMembersRestServlet(RestServlet): return 200, ret +class RoomStateRestServlet(RestServlet): + """ + Get full state within a room. + """ + + PATTERNS = admin_patterns("/rooms/(?P[^/]+)/state") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + ret = await self.store.get_room(room_id) + if not ret: + raise NotFoundError("Room not found") + + event_ids = await self.store.get_current_state_ids(room_id) + events = await self.store.get_events(event_ids.values()) + now = self.clock.time_msec() + room_state = await self._event_serializer.serialize_events( + events.values(), + now, + # We don't bother bundling aggregations in when asked for state + # events, as clients won't use them. + bundle_aggregations=False, + ) + ret = {"state": room_state} + + return 200, ret + + class JoinRoomAliasServlet(RestServlet): PATTERNS = admin_patterns("/join/(?P[^/]*)") diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index a0f32c5512..7c47aa7e0a 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1180,6 +1180,21 @@ class RoomTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["total"], 3) + def test_room_state(self): + """Test that room state can be requested correctly""" + # Create two test rooms + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,) + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIn("state", channel.json_body) + # testing that the state events match is painful and not done here. We assume that + # the create_room already does the right thing, so no need to verify that we got + # the state events it created. + class JoinAliasRoomTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1