From 0a363f9ca4b71187eb26b80dfa2cd72a35b1f8fd Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 22 Feb 2021 16:52:45 +0000 Subject: Remove cache for get_shared_rooms_for_users (#9416) This PR remove the cache for the `get_shared_rooms_for_users` storage method (the db method driving the experimental "what rooms do I share with this user?" feature: [MSC2666](https://github.com/matrix-org/matrix-doc/pull/2666)). Currently subsequent requests to the endpoint will return the same result, even if your shared rooms with that user have changed. The cache was added in https://github.com/matrix-org/synapse/pull/7785, but we forgot to ensure it was invalidated appropriately. Upon attempting to invalidate it, I found that the cache had to be entirely invalidated whenever a user (remote or local) joined or left a room. This didn't make for a very useful cache, especially for a function that may or may not be called very often. Thus, I've opted to remove it instead of invalidating it. --- tests/rest/client/v2_alpha/test_shared_rooms.py | 75 ++++++++++++++----------- 1 file changed, 41 insertions(+), 34 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py index 116ace1812..dd83a1f8ff 100644 --- a/tests/rest/client/v2_alpha/test_shared_rooms.py +++ b/tests/rest/client/v2_alpha/test_shared_rooms.py @@ -54,61 +54,62 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): A room should show up in the shared list of rooms between two users if it is public. """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - - room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) - self.helper.invite(room, src=u1, targ=u2, tok=u1_token) - self.helper.join(room, user=u2, tok=u2_token) - - channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room) + self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True) def test_shared_room_list_private(self): """ A room should show up in the shared list of rooms between two users if it is private. """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - - room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) - self.helper.invite(room, src=u1, targ=u2, tok=u1_token) - self.helper.join(room, user=u2, tok=u2_token) - - channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room) + self._check_shared_rooms_with( + room_one_is_public=False, room_two_is_public=False + ) def test_shared_room_list_mixed(self): """ The shared room list between two users should contain both public and private rooms. """ + self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False) + + def _check_shared_rooms_with( + self, room_one_is_public: bool, room_two_is_public: bool + ): + """Checks that shared public or private rooms between two users appear in + their shared room lists + """ u1 = self.register_user("user1", "pass") u1_token = self.login(u1, "pass") u2 = self.register_user("user2", "pass") u2_token = self.login(u2, "pass") - room_public = self.helper.create_room_as(u1, is_public=True, tok=u1_token) - room_private = self.helper.create_room_as(u2, is_public=False, tok=u2_token) - self.helper.invite(room_public, src=u1, targ=u2, tok=u1_token) - self.helper.invite(room_private, src=u2, targ=u1, tok=u2_token) - self.helper.join(room_public, user=u2, tok=u2_token) - self.helper.join(room_private, user=u1, tok=u1_token) + # Create a room. user1 invites user2, who joins + room_id_one = self.helper.create_room_as( + u1, is_public=room_one_is_public, tok=u1_token + ) + self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token) + self.helper.join(room_id_one, user=u2, tok=u2_token) + # Check shared rooms from user1's perspective. + # We should see the one room in common + channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 1) + self.assertEquals(channel.json_body["joined"][0], room_id_one) + + # Create another room and invite user2 to it + room_id_two = self.helper.create_room_as( + u1, is_public=room_two_is_public, tok=u1_token + ) + self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token) + self.helper.join(room_id_two, user=u2, tok=u2_token) + + # Check shared rooms again. We should now see both rooms. channel = self._get_shared_rooms(u1_token, u2) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 2) - self.assertTrue(room_public in channel.json_body["joined"]) - self.assertTrue(room_private in channel.json_body["joined"]) + for room_id_id in channel.json_body["joined"]: + self.assertIn(room_id_id, [room_id_one, room_id_two]) def test_shared_room_list_after_leave(self): """ @@ -132,6 +133,12 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.helper.leave(room, user=u1, tok=u1_token) + # Check user1's view of shared rooms with user2 + channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 0) + + # Check user2's view of shared rooms with user1 channel = self._get_shared_rooms(u2_token, u1) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 0) -- cgit 1.5.1 From 71c9f8de6d9fb5323c8625167f716e1ec381cb87 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 22 Feb 2021 20:38:51 +0100 Subject: Add an `order_by` field to list users' media admin API. (#8978) --- changelog.d/8978.feature | 1 + docs/admin_api/user_admin_api.rst | 38 +++- synapse/rest/admin/users.py | 28 ++- synapse/storage/databases/main/media_repository.py | 41 +++- tests/rest/admin/test_user.py | 246 +++++++++++++++++++-- 5 files changed, 325 insertions(+), 29 deletions(-) create mode 100644 changelog.d/8978.feature (limited to 'tests') diff --git a/changelog.d/8978.feature b/changelog.d/8978.feature new file mode 100644 index 0000000000..042e257bf0 --- /dev/null +++ b/changelog.d/8978.feature @@ -0,0 +1 @@ +Add `order_by` to the admin API `GET /_synapse/admin/v1/users//media`. Contributed by @dklimpel. \ No newline at end of file diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index 33dfbcfb49..8d4ec5a6f9 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -379,11 +379,12 @@ The following fields are returned in the JSON response body: - ``total`` - Number of rooms. -List media of an user -================================ +List media of a user +==================== Gets a list of all local media that a specific ``user_id`` has created. -The response is ordered by creation date descending and media ID descending. -The newest media is on top. +By default, the response is ordered by descending creation date and ascending media ID. +The newest media is on top. You can change the order with parameters +``order_by`` and ``dir``. The API is:: @@ -440,6 +441,35 @@ The following parameters should be set in the URL: denoting the offset in the returned results. This should be treated as an opaque value and not explicitly set to anything other than the return value of ``next_token`` from a previous call. Defaults to ``0``. +- ``order_by`` - The method by which to sort the returned list of media. + If the ordered field has duplicates, the second order is always by ascending ``media_id``, + which guarantees a stable ordering. Valid values are: + + - ``media_id`` - Media are ordered alphabetically by ``media_id``. + - ``upload_name`` - Media are ordered alphabetically by name the media was uploaded with. + - ``created_ts`` - Media are ordered by when the content was uploaded in ms. + Smallest to largest. This is the default. + - ``last_access_ts`` - Media are ordered by when the content was last accessed in ms. + Smallest to largest. + - ``media_length`` - Media are ordered by length of the media in bytes. + Smallest to largest. + - ``media_type`` - Media are ordered alphabetically by MIME-type. + - ``quarantined_by`` - Media are ordered alphabetically by the user ID that + initiated the quarantine request for this media. + - ``safe_from_quarantine`` - Media are ordered by the status if this media is safe + from quarantining. + +- ``dir`` - Direction of media order. Either ``f`` for forwards or ``b`` for backwards. + Setting this value to ``b`` will reverse the above sort order. Defaults to ``f``. + +If neither ``order_by`` nor ``dir`` is set, the default order is newest media on top +(corresponds to ``order_by`` = ``created_ts`` and ``dir`` = ``b``). + +Caution. The database only has indexes on the columns ``media_id``, +``user_id`` and ``created_ts``. This means that if a different sort order is used +(``upload_name``, ``last_access_ts``, ``media_length``, ``media_type``, +``quarantined_by`` or ``safe_from_quarantine``), this can cause a large load on the +database, especially for large environments. **Response** diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 998a0ef671..9c701c7348 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -35,6 +35,7 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.rest.client.v2_alpha._base import client_patterns +from synapse.storage.databases.main.media_repository import MediaSortOrder from synapse.types import JsonDict, UserID if TYPE_CHECKING: @@ -832,8 +833,33 @@ class UserMediaRestServlet(RestServlet): errcode=Codes.INVALID_PARAM, ) + # If neither `order_by` nor `dir` is set, set the default order + # to newest media is on top for backward compatibility. + if b"order_by" not in request.args and b"dir" not in request.args: + order_by = MediaSortOrder.CREATED_TS.value + direction = "b" + else: + order_by = parse_string( + request, + "order_by", + default=MediaSortOrder.CREATED_TS.value, + allowed_values=( + MediaSortOrder.MEDIA_ID.value, + MediaSortOrder.UPLOAD_NAME.value, + MediaSortOrder.CREATED_TS.value, + MediaSortOrder.LAST_ACCESS_TS.value, + MediaSortOrder.MEDIA_LENGTH.value, + MediaSortOrder.MEDIA_TYPE.value, + MediaSortOrder.QUARANTINED_BY.value, + MediaSortOrder.SAFE_FROM_QUARANTINE.value, + ), + ) + direction = parse_string( + request, "dir", default="f", allowed_values=("f", "b") + ) + media, total = await self.store.get_local_media_by_user_paginate( - start, limit, user_id + start, limit, user_id, order_by, direction ) ret = {"media": media, "total": total} diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 9ee642c668..274f8de595 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum from typing import Any, Dict, Iterable, List, Optional, Tuple from synapse.storage._base import SQLBaseStore @@ -23,6 +24,22 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = ( ) +class MediaSortOrder(Enum): + """ + Enum to define the sorting method used when returning media with + get_local_media_by_user_paginate + """ + + MEDIA_ID = "media_id" + UPLOAD_NAME = "upload_name" + CREATED_TS = "created_ts" + LAST_ACCESS_TS = "last_access_ts" + MEDIA_LENGTH = "media_length" + MEDIA_TYPE = "media_type" + QUARANTINED_BY = "quarantined_by" + SAFE_FROM_QUARANTINE = "safe_from_quarantine" + + class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -118,7 +135,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) async def get_local_media_by_user_paginate( - self, start: int, limit: int, user_id: str + self, + start: int, + limit: int, + user_id: str, + order_by: MediaSortOrder = MediaSortOrder.CREATED_TS.value, + direction: str = "f", ) -> Tuple[List[Dict[str, Any]], int]: """Get a paginated list of metadata for a local piece of media which an user_id has uploaded @@ -127,6 +149,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): start: offset in the list limit: maximum amount of media_ids to retrieve user_id: fully-qualified user id + order_by: the sort order of the returned list + direction: sort ascending or descending Returns: A paginated list of all metadata of user's media, plus the total count of all the user's media @@ -134,6 +158,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def get_local_media_by_user_paginate_txn(txn): + # Set ordering + order_by_column = MediaSortOrder(order_by).value + + if direction == "b": + order = "DESC" + else: + order = "ASC" + args = [user_id] sql = """ SELECT COUNT(*) as total_media @@ -155,9 +187,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "safe_from_quarantine" FROM local_media_repository WHERE user_id = ? - ORDER BY created_ts DESC, media_id DESC + ORDER BY {order_by_column} {order}, media_id ASC LIMIT ? OFFSET ? - """ + """.format( + order_by_column=order_by_column, + order=order, + ) args += [limit, start] txn.execute(sql, args) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ba26895391..e58d5cf0db 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -18,7 +18,7 @@ import hmac import json import urllib.parse from binascii import unhexlify -from typing import Optional +from typing import List, Optional from mock import Mock @@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import devices, sync from synapse.types import JsonDict from tests import unittest +from tests.server import FakeSite, make_request from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -1954,6 +1955,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() self.media_repo = hs.get_media_repository_resource() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -2024,7 +2026,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): number_media = 20 other_user_tok = self.login("user", "pass") - self._create_media(other_user_tok, number_media) + self._create_media_for_user(other_user_tok, number_media) channel = self.make_request( "GET", @@ -2045,7 +2047,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): number_media = 20 other_user_tok = self.login("user", "pass") - self._create_media(other_user_tok, number_media) + self._create_media_for_user(other_user_tok, number_media) channel = self.make_request( "GET", @@ -2066,7 +2068,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): number_media = 20 other_user_tok = self.login("user", "pass") - self._create_media(other_user_tok, number_media) + self._create_media_for_user(other_user_tok, number_media) channel = self.make_request( "GET", @@ -2080,11 +2082,31 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["media"]), 10) self._check_fields(channel.json_body["media"]) - def test_limit_is_negative(self): + def test_invalid_parameter(self): """ - Testing that a negative limit parameter returns a 400 + If parameters are invalid, an error is returned. """ + # unkown order_by + channel = self.make_request( + "GET", + self.url + "?order_by=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + # invalid search order + channel = self.make_request( + "GET", + self.url + "?dir=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # negative limit channel = self.make_request( "GET", self.url + "?limit=-5", @@ -2094,11 +2116,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_from_is_negative(self): - """ - Testing that a negative from parameter returns a 400 - """ - + # negative from channel = self.make_request( "GET", self.url + "?from=-5", @@ -2115,7 +2133,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): number_media = 20 other_user_tok = self.login("user", "pass") - self._create_media(other_user_tok, number_media) + self._create_media_for_user(other_user_tok, number_media) # `next_token` does not appear # Number of results is the number of entries @@ -2193,7 +2211,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): number_media = 5 other_user_tok = self.login("user", "pass") - self._create_media(other_user_tok, number_media) + self._create_media_for_user(other_user_tok, number_media) channel = self.make_request( "GET", @@ -2207,11 +2225,118 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["media"]) - def _create_media(self, user_token, number_media): + def test_order_by(self): + """ + Testing order list with parameter `order_by` + """ + + other_user_tok = self.login("user", "pass") + + # Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B + image_data1 = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + # Resolution: 1×1, MIME type: image/gif, Extension: gif, Size: 35 B + image_data2 = unhexlify( + b"47494638376101000100800100000000" + b"ffffff2c00000000010001000002024c" + b"01003b" + ) + # Resolution: 1×1, MIME type: image/bmp, Extension: bmp, Size: 54 B + image_data3 = unhexlify( + b"424d3a0000000000000036000000280000000100000001000000" + b"0100180000000000040000000000000000000000000000000000" + b"0000" + ) + + # create media and make sure they do not have the same timestamp + media1 = self._create_media_and_access(other_user_tok, image_data1, "image.png") + self.pump(1.0) + media2 = self._create_media_and_access(other_user_tok, image_data2, "image.gif") + self.pump(1.0) + media3 = self._create_media_and_access(other_user_tok, image_data3, "image.bmp") + self.pump(1.0) + + # Mark one media as safe from quarantine. + self.get_success(self.store.mark_local_media_as_safe(media2)) + # Quarantine one media + self.get_success( + self.store.quarantine_media_by_id("test", media3, self.admin_user) + ) + + # order by default ("created_ts") + # default is backwards + self._order_test([media3, media2, media1], None) + self._order_test([media1, media2, media3], None, "f") + self._order_test([media3, media2, media1], None, "b") + + # sort by media_id + sorted_media = sorted([media1, media2, media3], reverse=False) + sorted_media_reverse = sorted(sorted_media, reverse=True) + + # order by media_id + self._order_test(sorted_media, "media_id") + self._order_test(sorted_media, "media_id", "f") + self._order_test(sorted_media_reverse, "media_id", "b") + + # order by upload_name + self._order_test([media3, media2, media1], "upload_name") + self._order_test([media3, media2, media1], "upload_name", "f") + self._order_test([media1, media2, media3], "upload_name", "b") + + # order by media_type + # result is ordered by media_id + # because of uploaded media_type is always 'application/json' + self._order_test(sorted_media, "media_type") + self._order_test(sorted_media, "media_type", "f") + self._order_test(sorted_media, "media_type", "b") + + # order by media_length + self._order_test([media2, media3, media1], "media_length") + self._order_test([media2, media3, media1], "media_length", "f") + self._order_test([media1, media3, media2], "media_length", "b") + + # order by created_ts + self._order_test([media1, media2, media3], "created_ts") + self._order_test([media1, media2, media3], "created_ts", "f") + self._order_test([media3, media2, media1], "created_ts", "b") + + # order by last_access_ts + self._order_test([media1, media2, media3], "last_access_ts") + self._order_test([media1, media2, media3], "last_access_ts", "f") + self._order_test([media3, media2, media1], "last_access_ts", "b") + + # order by quarantined_by + # one media is in quarantine, others are ordered by media_ids + + # Different sort order of SQlite and PostreSQL + # If a media is not in quarantine `quarantined_by` is NULL + # SQLite considers NULL to be smaller than any other value. + # PostreSQL considers NULL to be larger than any other value. + + # self._order_test(sorted([media1, media2]) + [media3], "quarantined_by") + # self._order_test(sorted([media1, media2]) + [media3], "quarantined_by", "f") + # self._order_test([media3] + sorted([media1, media2]), "quarantined_by", "b") + + # order by safe_from_quarantine + # one media is safe from quarantine, others are ordered by media_ids + self._order_test(sorted([media1, media3]) + [media2], "safe_from_quarantine") + self._order_test( + sorted([media1, media3]) + [media2], "safe_from_quarantine", "f" + ) + self._order_test( + [media2] + sorted([media1, media3]), "safe_from_quarantine", "b" + ) + + def _create_media_for_user(self, user_token: str, number_media: int): """ Create a number of media for a specific user + Args: + user_token: Access token of the user + number_media: Number of media to be created for the user """ - upload_resource = self.media_repo.children[b"upload"] for i in range(number_media): # file size is 67 Byte image_data = unhexlify( @@ -2220,13 +2345,60 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): b"0a2db40000000049454e44ae426082" ) - # Upload some media into the room - self.helper.upload_media( - upload_resource, image_data, tok=user_token, expect_code=200 - ) + self._create_media_and_access(user_token, image_data) + + def _create_media_and_access( + self, + user_token: str, + image_data: bytes, + filename: str = "image1.png", + ) -> str: + """ + Create one media for a specific user, access and returns `media_id` + Args: + user_token: Access token of the user + image_data: binary data of image + filename: The filename of the media to be uploaded + Returns: + The ID of the newly created media. + """ + upload_resource = self.media_repo.children[b"upload"] + download_resource = self.media_repo.children[b"download"] + + # Upload some media into the room + response = self.helper.upload_media( + upload_resource, image_data, user_token, filename, expect_code=200 + ) + + # Extract media ID from the response + server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' + media_id = server_and_media_id.split("/")[1] + + # Try to access a media and to create `last_access_ts` + channel = make_request( + self.reactor, + FakeSite(download_resource), + "GET", + server_and_media_id, + shorthand=False, + access_token=user_token, + ) + + self.assertEqual( + 200, + channel.code, + msg=( + "Expected to receive a 200 on accessing media: %s" % server_and_media_id + ), + ) - def _check_fields(self, content): - """Checks that all attributes are present in content""" + return media_id + + def _check_fields(self, content: JsonDict): + """Checks that the expected user attributes are present in content + Args: + content: List that is checked for content + """ for m in content: self.assertIn("media_id", m) self.assertIn("media_type", m) @@ -2237,6 +2409,38 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertIn("quarantined_by", m) self.assertIn("safe_from_quarantine", m) + def _order_test( + self, + expected_media_list: List[str], + order_by: Optional[str], + dir: Optional[str] = None, + ): + """Request the list of media in a certain order. Assert that order is what + we expect + Args: + expected_media_list: The list of media_ids in the order we expect to get + back from the server + order_by: The type of ordering to give the server + dir: The direction of ordering to give the server + """ + + url = self.url + "?" + if order_by is not None: + url += "order_by=%s&" % (order_by,) + if dir is not None and dir in ("b", "f"): + url += "dir=%s" % (dir,) + channel = self.make_request( + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], len(expected_media_list)) + + returned_order = [row["media_id"] for row in channel.json_body["media"]] + self.assertEqual(expected_media_list, returned_order) + self._check_fields(channel.json_body["media"]) + class UserTokenRestTestCase(unittest.HomeserverTestCase): """Test for /_synapse/admin/v1/users//login""" -- cgit 1.5.1 From 1b2d6d55c5f4e68053705c46a968e0b2815ab5de Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 22 Feb 2021 19:54:49 +0000 Subject: Remove vestiges of uploads_path config (#9462) `uploads_path` was a thing that was never used; most of it was removed in #6628 but a few vestiges remained. --- changelog.d/9462.misc | 1 + docker/README.md | 1 - docker/conf/homeserver.yaml | 1 - synapse/config/repository.py | 1 - tests/utils.py | 1 - 5 files changed, 1 insertion(+), 4 deletions(-) create mode 100644 changelog.d/9462.misc (limited to 'tests') diff --git a/changelog.d/9462.misc b/changelog.d/9462.misc new file mode 100644 index 0000000000..1b245bf85d --- /dev/null +++ b/changelog.d/9462.misc @@ -0,0 +1 @@ +Remove vestiges of `uploads_path` configuration setting. diff --git a/docker/README.md b/docker/README.md index c8f27b8566..7b138df4d3 100644 --- a/docker/README.md +++ b/docker/README.md @@ -11,7 +11,6 @@ The image also does *not* provide a TURN server. By default, the image expects a single volume, located at ``/data``, that will hold: * configuration files; -* temporary files during uploads; * uploaded media and thumbnails; * the SQLite database if you do not configure postgres; * the appservices configuration. diff --git a/docker/conf/homeserver.yaml b/docker/conf/homeserver.yaml index 2ed570a5d1..0dea62a87d 100644 --- a/docker/conf/homeserver.yaml +++ b/docker/conf/homeserver.yaml @@ -89,7 +89,6 @@ federation_rc_concurrent: 3 ## Files ## media_store_path: "/data/media" -uploads_path: "/data/uploads" max_upload_size: "{{ SYNAPSE_MAX_UPLOAD_SIZE or "50M" }}" max_image_pixels: "32M" dynamic_thumbnails: false diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 52849c3256..69d9de5a43 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -206,7 +206,6 @@ class ContentRepositoryConfig(Config): def generate_config_section(self, data_dir_path, **kwargs): media_store = os.path.join(data_dir_path, "media_store") - uploads_path = os.path.join(data_dir_path, "uploads") formatted_thumbnail_sizes = "".join( THUMBNAIL_SIZE_YAML % s for s in DEFAULT_THUMBNAIL_SIZES diff --git a/tests/utils.py b/tests/utils.py index 4fb5098550..be80b13760 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -114,7 +114,6 @@ def default_config(name, parse=False): "server_name": name, "send_federation": False, "media_store_path": "media", - "uploads_path": "uploads", # the test signing key is just an arbitrary ed25519 key to keep the config # parser happy "signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg", -- cgit 1.5.1 From 292792194232e68eb68c5135bb59d037d38b870b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 24 Feb 2021 13:23:18 +0000 Subject: Clean up `ShardedWorkerHandlingConfig` (#9466) * Split ShardedWorkerHandlingConfig This is so that we have a type level understanding of when it is safe to call `get_instance(..)` (as opposed to `should_handle(..)`). * Remove special cases in ShardedWorkerHandlingConfig. `ShardedWorkerHandlingConfig` tried to handle the various different ways it was possible to configure federation senders and pushers. This led to special cases that weren't hit during testing. To fix this the handling of the different cases is moved from there and `generic_worker` into the worker config class. This allows us to have the logic in one place and allows the rest of the code to ignore the different cases. --- changelog.d/9466.bugfix | 1 + synapse/app/admin_cmd.py | 2 + synapse/app/generic_worker.py | 32 -------- synapse/config/_base.py | 36 ++++++--- synapse/config/_base.pyi | 2 + synapse/config/push.py | 5 +- synapse/config/server.py | 1 - synapse/config/workers.py | 93 +++++++++++++++++++++-- synapse/push/pusherpool.py | 4 +- synapse/server.py | 7 +- tests/replication/tcp/streams/test_federation.py | 2 +- tests/replication/test_federation_ack.py | 2 +- tests/replication/test_federation_sender_shard.py | 2 +- tests/replication/test_pusher_shard.py | 2 +- 14 files changed, 128 insertions(+), 63 deletions(-) create mode 100644 changelog.d/9466.bugfix (limited to 'tests') diff --git a/changelog.d/9466.bugfix b/changelog.d/9466.bugfix new file mode 100644 index 0000000000..2ab4f315c1 --- /dev/null +++ b/changelog.d/9466.bugfix @@ -0,0 +1 @@ +Fix deleting pushers when using sharded pushers. diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index b4bd4d8e7a..9f99651aa2 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -210,7 +210,9 @@ def start(config_options): config.update_user_directory = False config.run_background_tasks = False config.start_pushers = False + config.pusher_shard_config.instances = [] config.send_federation = False + config.federation_shard_config.instances = [] synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 6ed405e66b..dc0d3eb725 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -919,22 +919,6 @@ def start(config_options): # For other worker types we force this to off. config.appservice.notify_appservices = False - if config.worker_app == "synapse.app.pusher": - if config.server.start_pushers: - sys.stderr.write( - "\nThe pushers must be disabled in the main synapse process" - "\nbefore they can be run in a separate worker." - "\nPlease add ``start_pushers: false`` to the main config" - "\n" - ) - sys.exit(1) - - # Force the pushers to start since they will be disabled in the main config - config.server.start_pushers = True - else: - # For other worker types we force this to off. - config.server.start_pushers = False - if config.worker_app == "synapse.app.user_dir": if config.server.update_user_directory: sys.stderr.write( @@ -951,22 +935,6 @@ def start(config_options): # For other worker types we force this to off. config.server.update_user_directory = False - if config.worker_app == "synapse.app.federation_sender": - if config.worker.send_federation: - sys.stderr.write( - "\nThe send_federation must be disabled in the main synapse process" - "\nbefore they can be run in a separate worker." - "\nPlease add ``send_federation: false`` to the main config" - "\n" - ) - sys.exit(1) - - # Force the pushers to start since they will be disabled in the main config - config.worker.send_federation = True - else: - # For other worker types we force this to off. - config.worker.send_federation = False - synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts hs = GenericWorkerServer( diff --git a/synapse/config/_base.py b/synapse/config/_base.py index e89decda34..4026966711 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -844,22 +844,23 @@ class ShardedWorkerHandlingConfig: def should_handle(self, instance_name: str, key: str) -> bool: """Whether this instance is responsible for handling the given key.""" - # If multiple instances are not defined we always return true - if not self.instances or len(self.instances) == 1: - return True + # If no instances are defined we assume some other worker is handling + # this. + if not self.instances: + return False - return self.get_instance(key) == instance_name + return self._get_instance(key) == instance_name - def get_instance(self, key: str) -> str: + def _get_instance(self, key: str) -> str: """Get the instance responsible for handling the given key. - Note: For things like federation sending the config for which instance - is sending is known only to the sender instance if there is only one. - Therefore `should_handle` should be used where possible. + Note: For federation sending and pushers the config for which instance + is sending is known only to the sender instance, so we don't expose this + method by default. """ if not self.instances: - return "master" + raise Exception("Unknown worker") if len(self.instances) == 1: return self.instances[0] @@ -876,4 +877,21 @@ class ShardedWorkerHandlingConfig: return self.instances[remainder] +@attr.s +class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig): + """A version of `ShardedWorkerHandlingConfig` that is used for config + options where all instances know which instances are responsible for the + sharded work. + """ + + def __attrs_post_init__(self): + # We require that `self.instances` is non-empty. + if not self.instances: + raise Exception("Got empty list of instances for shard config") + + def get_instance(self, key: str) -> str: + """Get the instance responsible for handling the given key.""" + return self._get_instance(key) + + __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 70025b5d60..db16c86f50 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -149,4 +149,6 @@ class ShardedWorkerHandlingConfig: instances: List[str] def __init__(self, instances: List[str]) -> None: ... def should_handle(self, instance_name: str, key: str) -> bool: ... + +class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig): def get_instance(self, key: str) -> str: ... diff --git a/synapse/config/push.py b/synapse/config/push.py index 3adbfb73e6..7831a2ef79 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config, ShardedWorkerHandlingConfig +from ._base import Config class PushConfig(Config): @@ -27,9 +27,6 @@ class PushConfig(Config): "group_unread_count_by_room", True ) - pusher_instances = config.get("pusher_instances") or [] - self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances) - # There was a a 'redact_content' setting but mistakenly read from the # 'email'section'. Check for the flag in the 'push' section, and log, # but do not honour it to avoid nasty surprises when people upgrade. diff --git a/synapse/config/server.py b/synapse/config/server.py index 0bfd4398e2..2afca36e7d 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -397,7 +397,6 @@ class ServerConfig(Config): if self.public_baseurl is not None: if self.public_baseurl[-1] != "/": self.public_baseurl += "/" - self.start_pushers = config.get("start_pushers", True) # (undocumented) option for torturing the worker-mode replication a bit, # for testing. The value defines the number of milliseconds to pause before diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 7a0ca16da8..ac92375a85 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -17,9 +17,28 @@ from typing import List, Union import attr -from ._base import Config, ConfigError, ShardedWorkerHandlingConfig +from ._base import ( + Config, + ConfigError, + RoutableShardedWorkerHandlingConfig, + ShardedWorkerHandlingConfig, +) from .server import ListenerConfig, parse_listener_def +_FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR = """ +The send_federation config option must be disabled in the main +synapse process before they can be run in a separate worker. + +Please add ``send_federation: false`` to the main config +""" + +_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR = """ +The start_pushers config option must be disabled in the main +synapse process before they can be run in a separate worker. + +Please add ``start_pushers: false`` to the main config +""" + def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: """Helper for allowing parsing a string or list of strings to a config @@ -103,6 +122,7 @@ class WorkerConfig(Config): self.worker_replication_secret = config.get("worker_replication_secret", None) self.worker_name = config.get("worker_name", self.worker_app) + self.instance_name = self.worker_name or "master" self.worker_main_http_uri = config.get("worker_main_http_uri", None) @@ -118,12 +138,41 @@ class WorkerConfig(Config): ) ) - # Whether to send federation traffic out in this process. This only - # applies to some federation traffic, and so shouldn't be used to - # "disable" federation - self.send_federation = config.get("send_federation", True) + # Handle federation sender configuration. + # + # There are two ways of configuring which instances handle federation + # sending: + # 1. The old way where "send_federation" is set to false and running a + # `synapse.app.federation_sender` worker app. + # 2. Specifying the workers sending federation in + # `federation_sender_instances`. + # + + send_federation = config.get("send_federation", True) + + federation_sender_instances = config.get("federation_sender_instances") + if federation_sender_instances is None: + # Default to an empty list, which means "another, unknown, worker is + # responsible for it". + federation_sender_instances = [] - federation_sender_instances = config.get("federation_sender_instances") or [] + # If no federation sender instances are set we check if + # `send_federation` is set, which means use master + if send_federation: + federation_sender_instances = ["master"] + + if self.worker_app == "synapse.app.federation_sender": + if send_federation: + # If we're running federation senders, and not using + # `federation_sender_instances`, then we should have + # explicitly set `send_federation` to false. + raise ConfigError( + _FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR + ) + + federation_sender_instances = [self.worker_name] + + self.send_federation = self.instance_name in federation_sender_instances self.federation_shard_config = ShardedWorkerHandlingConfig( federation_sender_instances ) @@ -164,7 +213,37 @@ class WorkerConfig(Config): "Must only specify one instance to handle `receipts` messages." ) - self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) + if len(self.writers.events) == 0: + raise ConfigError("Must specify at least one instance to handle `events`.") + + self.events_shard_config = RoutableShardedWorkerHandlingConfig( + self.writers.events + ) + + # Handle sharded push + start_pushers = config.get("start_pushers", True) + pusher_instances = config.get("pusher_instances") + if pusher_instances is None: + # Default to an empty list, which means "another, unknown, worker is + # responsible for it". + pusher_instances = [] + + # If no pushers instances are set we check if `start_pushers` is + # set, which means use master + if start_pushers: + pusher_instances = ["master"] + + if self.worker_app == "synapse.app.pusher": + if start_pushers: + # If we're running pushers, and not using + # `pusher_instances`, then we should have explicitly set + # `start_pushers` to false. + raise ConfigError(_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR) + + pusher_instances = [self.instance_name] + + self.start_pushers = self.instance_name in pusher_instances + self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances) # Whether this worker should run background tasks or not. # diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 3936bf8784..21f14f05f0 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -59,7 +59,6 @@ class PusherPool: def __init__(self, hs: "HomeServer"): self.hs = hs self.pusher_factory = PusherFactory(hs) - self._should_start_pushers = hs.config.start_pushers self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() @@ -68,6 +67,9 @@ class PusherPool: # We shard the handling of push notifications by user ID. self._pusher_shard_config = hs.config.push.pusher_shard_config self._instance_name = hs.get_instance_name() + self._should_start_pushers = ( + self._instance_name in self._pusher_shard_config.instances + ) # We can only delete pushers on master. self._remove_pusher_client = None diff --git a/synapse/server.py b/synapse/server.py index 5de8782000..4b9ec7f0ae 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -248,7 +248,7 @@ class HomeServer(metaclass=abc.ABCMeta): self.start_time = None # type: Optional[int] self._instance_id = random_string(5) - self._instance_name = config.worker_name or "master" + self._instance_name = config.worker.instance_name self.version_string = version_string @@ -760,7 +760,4 @@ class HomeServer(metaclass=abc.ABCMeta): def should_send_federation(self) -> bool: "Should this server be sending federation traffic directly?" - return self.config.send_federation and ( - not self.config.worker_app - or self.config.worker_app == "synapse.app.federation_sender" - ) + return self.config.send_federation diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py index 2babea4e3e..aa4bf1c7e3 100644 --- a/tests/replication/tcp/streams/test_federation.py +++ b/tests/replication/tcp/streams/test_federation.py @@ -24,7 +24,7 @@ class FederationStreamTestCase(BaseStreamTestCase): # enable federation sending on the worker config = super()._get_worker_hs_config() # TODO: make it so we don't need both of these - config["send_federation"] = True + config["send_federation"] = False config["worker_app"] = "synapse.app.federation_sender" return config diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 1853667558..f235f1bd83 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -27,7 +27,7 @@ class FederationAckTestCase(HomeserverTestCase): def default_config(self) -> dict: config = super().default_config() config["worker_app"] = "synapse.app.federation_sender" - config["send_federation"] = True + config["send_federation"] = False return config def make_homeserver(self, reactor, clock): diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index fffdb742c8..2f2d117858 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -49,7 +49,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): self.make_worker_hs( "synapse.app.federation_sender", - {"send_federation": True}, + {"send_federation": False}, federation_http_client=mock_client, ) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index f118fe32af..ab2988a6ba 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -95,7 +95,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): self.make_worker_hs( "synapse.app.pusher", - {"start_pushers": True}, + {"start_pushers": False}, proxied_blacklisted_http_client=http_client_mock, ) -- cgit 1.5.1 From 2566dc57ce69b1a141014cc7231f84ea6d3f3ea7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 25 Feb 2021 15:35:14 +0000 Subject: Test that we require validated email for email pushers (#9496) --- changelog.d/9496.misc | 1 + synapse/push/pusherpool.py | 6 ++++++ tests/push/test_email.py | 34 ++++++++++++++++++++++++++++++++-- 3 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 changelog.d/9496.misc (limited to 'tests') diff --git a/changelog.d/9496.misc b/changelog.d/9496.misc new file mode 100644 index 0000000000..d5866c56f7 --- /dev/null +++ b/changelog.d/9496.misc @@ -0,0 +1 @@ +Test that we require validated email for email pushers. diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 21f14f05f0..4c7f5fecee 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, Optional from prometheus_client import Gauge +from synapse.api.errors import Codes, SynapseError from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -113,6 +114,11 @@ class PusherPool: The newly created pusher. """ + if kind == "email": + email_owner = await self.store.get_user_id_by_threepid("email", pushkey) + if email_owner != user_id: + raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) + time_now_msec = self.clock.time_msec() # create the pusher setting last_stream_ordering to the current maximum diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 22f452ec24..941cf42429 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -21,6 +21,7 @@ import pkg_resources from twisted.internet.defer import Deferred import synapse.rest.admin +from synapse.api.errors import Codes, SynapseError from synapse.rest.client.v1 import login, room from tests.unittest import HomeserverTestCase @@ -100,12 +101,19 @@ class EmailPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(self.access_token) ) - token_id = user_tuple.token_id + self.token_id = user_tuple.token_id + + # We need to add email to account before we can create a pusher. + self.get_success( + hs.get_datastore().user_add_threepid( + self.user_id, "email", "a@example.com", 0, 0 + ) + ) self.pusher = self.get_success( self.hs.get_pusherpool().add_pusher( user_id=self.user_id, - access_token=token_id, + access_token=self.token_id, kind="email", app_id="m.email", app_display_name="Email Notifications", @@ -116,6 +124,28 @@ class EmailPusherTests(HomeserverTestCase): ) ) + def test_need_validated_email(self): + """Test that we can only add an email pusher if the user has validated + their email. + """ + with self.assertRaises(SynapseError) as cm: + self.get_success_or_raise( + self.hs.get_pusherpool().add_pusher( + user_id=self.user_id, + access_token=self.token_id, + kind="email", + app_id="m.email", + app_display_name="Email Notifications", + device_display_name="b@example.com", + pushkey="b@example.com", + lang=None, + data={}, + ) + ) + + self.assertEqual(400, cm.exception.code) + self.assertEqual(Codes.THREEPID_NOT_FOUND, cm.exception.errcode) + def test_simple_sends_email(self): # Create a simple room with two users room = self.helper.create_room_as(self.user_id, tok=self.access_token) -- cgit 1.5.1 From 15090de85075c9d7d54479b4bfd79057de64059b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 26 Feb 2021 14:02:06 +0000 Subject: SSO: redirect to public URL before setting cookies (#9436) ... otherwise, we don't get the cookie back. --- changelog.d/9436.bugfix | 1 + synapse/http/__init__.py | 37 +++++++++++++++++++- synapse/rest/client/v1/login.py | 28 +++++++++++++++ tests/rest/client/v1/test_login.py | 61 ++++++++++++++++++++------------- tests/rest/client/v1/utils.py | 19 +++++++++- tests/rest/client/v2_alpha/test_auth.py | 6 +++- tests/server.py | 6 +++- 7 files changed, 130 insertions(+), 28 deletions(-) create mode 100644 changelog.d/9436.bugfix (limited to 'tests') diff --git a/changelog.d/9436.bugfix b/changelog.d/9436.bugfix new file mode 100644 index 0000000000..a530516eed --- /dev/null +++ b/changelog.d/9436.bugfix @@ -0,0 +1 @@ +Fix a bug in single sign-on which could cause a "No session cookie found" error. diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index c658862fe6..142b007d01 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -14,8 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import re +from typing import Union -from twisted.internet import task +from twisted.internet import address, task from twisted.web.client import FileBodyProducer from twisted.web.iweb import IRequest @@ -53,6 +54,40 @@ class QuieterFileBodyProducer(FileBodyProducer): pass +def get_request_uri(request: IRequest) -> bytes: + """Return the full URI that was requested by the client""" + return b"%s://%s%s" % ( + b"https" if request.isSecure() else b"http", + _get_requested_host(request), + # despite its name, "request.uri" is only the path and query-string. + request.uri, + ) + + +def _get_requested_host(request: IRequest) -> bytes: + hostname = request.getHeader(b"host") + if hostname: + return hostname + + # no Host header, use the address/port that the request arrived on + host = request.getHost() # type: Union[address.IPv4Address, address.IPv6Address] + + hostname = host.host.encode("ascii") + + if request.isSecure() and host.port == 443: + # default port for https + return hostname + + if not request.isSecure() and host.port == 80: + # default port for http + return hostname + + return b"%s:%i" % ( + hostname, + host.port, + ) + + def get_request_user_agent(request: IRequest, default: str = "") -> str: """Return the last User-Agent header, or the given default.""" # There could be raw utf-8 bytes in the User-Agent header. diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 6e2fbedd99..925edfc402 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -20,6 +20,7 @@ from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.appservice import ApplicationService from synapse.handlers.sso import SsoIdentityProvider +from synapse.http import get_request_uri from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, @@ -354,6 +355,7 @@ class SsoRedirectServlet(RestServlet): hs.get_oidc_handler() self._sso_handler = hs.get_sso_handler() self._msc2858_enabled = hs.config.experimental.msc2858_enabled + self._public_baseurl = hs.config.public_baseurl def register(self, http_server: HttpServer) -> None: super().register(http_server) @@ -373,6 +375,32 @@ class SsoRedirectServlet(RestServlet): async def on_GET( self, request: SynapseRequest, idp_id: Optional[str] = None ) -> None: + if not self._public_baseurl: + raise SynapseError(400, "SSO requires a valid public_baseurl") + + # if this isn't the expected hostname, redirect to the right one, so that we + # get our cookies back. + requested_uri = get_request_uri(request) + baseurl_bytes = self._public_baseurl.encode("utf-8") + if not requested_uri.startswith(baseurl_bytes): + # swap out the incorrect base URL for the right one. + # + # The idea here is to redirect from + # https://foo.bar/whatever/_matrix/... + # to + # https://public.baseurl/_matrix/... + # + i = requested_uri.index(b"/_matrix") + new_uri = baseurl_bytes[:-1] + requested_uri[i:] + logger.info( + "Requested URI %s is not canonical: redirecting to %s", + requested_uri.decode("utf-8", errors="replace"), + new_uri.decode("utf-8", errors="replace"), + ) + request.redirect(new_uri) + finish_request(request) + return + client_redirect_url = parse_string( request, "redirectUrl", required=True, encoding=None ) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index fb29eaed6f..744d8d0941 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -15,7 +15,7 @@ import time import urllib.parse -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from urllib.parse import urlencode from mock import Mock @@ -47,8 +47,14 @@ except ImportError: HAS_JWT = False -# public_base_url used in some tests -BASE_URL = "https://synapse/" +# synapse server name: used to populate public_baseurl in some tests +SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse" + +# public_baseurl for some tests. It uses an http:// scheme because +# FakeChannel.isSecure() returns False, so synapse will see the requested uri as +# http://..., so using http in the public_baseurl stops Synapse trying to redirect to +# https://.... +BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,) # CAS server used in some tests CAS_SERVER = "https://fake.test" @@ -480,11 +486,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_multi_sso_redirect(self): """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker - channel = self.make_request( - "GET", - "/_matrix/client/r0/login/sso/redirect?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), - ) + channel = self._make_sso_redirect_request(False, None) self.assertEqual(channel.code, 302, channel.result) uri = channel.headers.getRawHeaders("Location")[0] @@ -628,34 +630,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_client_idp_redirect_msc2858_disabled(self): """If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" - channel = self.make_request( - "GET", - "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), - ) + channel = self._make_sso_redirect_request(True, "oidc") self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_client_idp_redirect_to_unknown(self): """If the client tries to pick an unknown IdP, return a 404""" - channel = self.make_request( - "GET", - "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), - ) + channel = self._make_sso_redirect_request(True, "xxx") self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_client_idp_redirect_to_oidc(self): """If the client pick a known IdP, redirect to it""" - channel = self.make_request( - "GET", - "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), - ) - + channel = self._make_sso_redirect_request(True, "oidc") self.assertEqual(channel.code, 302, channel.result) oidc_uri = channel.headers.getRawHeaders("Location")[0] oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) @@ -663,6 +652,30 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # it should redirect us to the auth page of the OIDC server self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + def _make_sso_redirect_request( + self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None + ): + """Send a request to /_matrix/client/r0/login/sso/redirect + + ... or the unstable equivalent + + ... possibly specifying an IDP provider + """ + endpoint = ( + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect" + if unstable_endpoint + else "/_matrix/client/r0/login/sso/redirect" + ) + if idp_prov is not None: + endpoint += "/" + idp_prov + endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + return self.make_request( + "GET", + endpoint, + custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)], + ) + @staticmethod def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: prefix = key + " = " diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 8231a423f3..946740aa5d 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -542,13 +542,30 @@ class RestHelper: if client_redirect_url: params["redirectUrl"] = client_redirect_url - # hit the redirect url (which will issue a cookie and state) + # hit the redirect url (which should redirect back to the redirect url. This + # is the easiest way of figuring out what the Host header ought to be set to + # to keep Synapse happy. channel = make_request( self.hs.get_reactor(), self.site, "GET", "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), ) + assert channel.code == 302 + + # hit the redirect url again with the right Host header, which should now issue + # a cookie and redirect to the SSO provider. + location = channel.headers.getRawHeaders("Location")[0] + parts = urllib.parse.urlsplit(location) + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + urllib.parse.urlunsplit(("", "") + parts[2:]), + custom_headers=[ + ("Host", parts[1]), + ], + ) assert channel.code == 302 channel.extract_cookies(cookies) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index c26ad824f7..9734a2159a 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -161,7 +161,11 @@ class UIAuthTests(unittest.HomeserverTestCase): def default_config(self): config = super().default_config() - config["public_baseurl"] = "https://synapse.test" + + # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns + # False, so synapse will see the requested uri as http://..., so using http in + # the public_baseurl stops Synapse trying to redirect to https. + config["public_baseurl"] = "http://synapse.test" if HAS_OIDC: # we enable OIDC as a way of testing SSO flows diff --git a/tests/server.py b/tests/server.py index d4ece5c448..939a0008ca 100644 --- a/tests/server.py +++ b/tests/server.py @@ -124,7 +124,11 @@ class FakeChannel: return address.IPv4Address("TCP", self._ip, 3423) def getHost(self): - return None + # this is called by Request.__init__ to configure Request.host. + return address.IPv4Address("TCP", "127.0.0.1", 8888) + + def isSecure(self): + return False @property def transport(self): -- cgit 1.5.1