From d3f9afd8d9db8c80b342177b9ab162c79357c431 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 19 Jul 2024 16:19:15 +0100 Subject: Add a cache on `get_rooms_for_local_user_where_membership_is` (#17460) As it gets used in sliding sync. We basically invalidate it in all the same places as `get_rooms_for_user`. Most of the changes are due to needing the arguments you pass in to be hashable (which lists aren't) --- synapse/storage/databases/main/cache.py | 6 ++++++ synapse/storage/databases/main/roommember.py | 26 +++++++++++++++++++++++--- 2 files changed, 29 insertions(+), 3 deletions(-) (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 2d6b75e47e..26b8e1a172 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -331,6 +331,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): "get_invited_rooms_for_local_user", (state_key,) ) self._attempt_to_invalidate_cache("get_rooms_for_user", (state_key,)) + self._attempt_to_invalidate_cache( + "_get_rooms_for_local_user_where_membership_is_inner", (state_key,) + ) self._attempt_to_invalidate_cache( "did_forget", @@ -393,6 +396,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_thread_id_for_receipts", None) self._attempt_to_invalidate_cache("get_invited_rooms_for_local_user", None) self._attempt_to_invalidate_cache("get_rooms_for_user", None) + self._attempt_to_invalidate_cache( + "_get_rooms_for_local_user_where_membership_is_inner", None + ) self._attempt_to_invalidate_cache("did_forget", None) self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None) self._attempt_to_invalidate_cache("get_references_for_event", None) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index f62d9f705d..640ab123f0 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -445,9 +445,11 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): if not membership_list: return [] - rooms = await self.db_pool.runInteraction( - "get_rooms_for_local_user_where_membership_is", - self._get_rooms_for_local_user_where_membership_is_txn, + # Convert membership list to frozen set as a) it needs to be hashable, + # and b) we don't care about the order. + membership_list = frozenset(membership_list) + + rooms = await self._get_rooms_for_local_user_where_membership_is_inner( user_id, membership_list, ) @@ -466,6 +468,24 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): return [room for room in rooms if room.room_id not in rooms_to_exclude] + @cached(max_entries=1000, tree=True) + async def _get_rooms_for_local_user_where_membership_is_inner( + self, + user_id: str, + membership_list: Collection[str], + ) -> Sequence[RoomsForUser]: + if not membership_list: + return [] + + rooms = await self.db_pool.runInteraction( + "get_rooms_for_local_user_where_membership_is", + self._get_rooms_for_local_user_where_membership_is_txn, + user_id, + membership_list, + ) + + return rooms + def _get_rooms_for_local_user_where_membership_is_txn( self, txn: LoggingTransaction, -- cgit 1.5.1 From dc8ddc6472ba19905b3fd0c4f4da4088223e03b0 Mon Sep 17 00:00:00 2001 From: Shay Date: Mon, 22 Jul 2024 02:33:17 -0700 Subject: Prepare for authenticated media freeze (#17433) As part of the rollout of [MSC3916](https://github.com/matrix-org/matrix-spec-proposals/blob/main/proposals/3916-authentication-for-media.md) this PR adds support for designating authenticated media and ensuring that authenticated media is not served over unauthenticated endpoints. --- changelog.d/17433.feature | 1 + docs/usage/configuration/config_documentation.md | 12 ++ synapse/_scripts/synapse_port_db.py | 5 +- synapse/config/repository.py | 4 + synapse/media/media_repository.py | 39 +++- synapse/media/thumbnailer.py | 48 ++++- synapse/rest/media/download_resource.py | 3 +- synapse/rest/media/thumbnail_resource.py | 5 +- synapse/storage/databases/main/media_repository.py | 28 ++- synapse/storage/schema/__init__.py | 5 +- .../schema/main/delta/86/01_authenticate_media.sql | 15 ++ tests/rest/client/test_media.py | 209 +++++++++++++++++++++ 12 files changed, 362 insertions(+), 12 deletions(-) create mode 100644 changelog.d/17433.feature create mode 100644 synapse/storage/schema/main/delta/86/01_authenticate_media.sql (limited to 'synapse/storage/databases') diff --git a/changelog.d/17433.feature b/changelog.d/17433.feature new file mode 100644 index 0000000000..ac9b5dee69 --- /dev/null +++ b/changelog.d/17433.feature @@ -0,0 +1 @@ +Prepare for authenticated media freeze. \ No newline at end of file diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 38b24b5044..e8bc2df798 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -1863,6 +1863,18 @@ federation_rr_transactions_per_room_per_second: 40 ## Media Store Config options related to Synapse's media store. +--- +### `enable_authenticated_media` + +When set to true, all subsequent media uploads will be marked as authenticated, and will not be available over legacy +unauthenticated media endpoints (`/_matrix/media/(r0|v3|v1)/download` and `/_matrix/media/(r0|v3|v1)/thumbnail`) - requests for authenticated media over these endpoints will result in a 404. All media, including authenticated media, will be available over the authenticated media endpoints `_matrix/client/v1/media/download` and `_matrix/client/v1/media/thumbnail`. Media uploaded prior to setting this option to true will still be available over the legacy endpoints. Note if the setting is switched to false +after enabling, media marked as authenticated will be available over legacy endpoints. Defaults to false, but +this will change to true in a future Synapse release. + +Example configuration: +```yaml +enable_authenticated_media: true +``` --- ### `enable_media_repo` diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 3bb4a34938..5c6db8118f 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -119,18 +119,19 @@ BOOLEAN_COLUMNS = { "e2e_room_keys": ["is_verified"], "event_edges": ["is_state"], "events": ["processed", "outlier", "contains_url"], - "local_media_repository": ["safe_from_quarantine"], + "local_media_repository": ["safe_from_quarantine", "authenticated"], + "per_user_experimental_features": ["enabled"], "presence_list": ["accepted"], "presence_stream": ["currently_active"], "public_room_list_stream": ["visibility"], "pushers": ["enabled"], "redactions": ["have_censored"], + "remote_media_cache": ["authenticated"], "room_stats_state": ["is_federatable"], "rooms": ["is_public", "has_auth_chain_index"], "users": ["shadow_banned", "approved", "locked", "suspended"], "un_partial_stated_event_stream": ["rejection_status_changed"], "users_who_share_rooms": ["share_private"], - "per_user_experimental_features": ["enabled"], } diff --git a/synapse/config/repository.py b/synapse/config/repository.py index dc0e93ffa1..97ce6de528 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -272,6 +272,10 @@ class ContentRepositoryConfig(Config): remote_media_lifetime ) + self.enable_authenticated_media = config.get( + "enable_authenticated_media", False + ) + def generate_config_section(self, data_dir_path: str, **kwargs: Any) -> str: assert data_dir_path is not None media_store = os.path.join(data_dir_path, "media_store") diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 87c929eb20..8bc92305fe 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -430,6 +430,7 @@ class MediaRepository: media_id: str, name: Optional[str], max_timeout_ms: int, + allow_authenticated: bool = True, federation: bool = False, ) -> None: """Responds to requests for local media, if exists, or returns 404. @@ -442,6 +443,7 @@ class MediaRepository: the filename in the Content-Disposition header of the response. max_timeout_ms: the maximum number of milliseconds to wait for the media to be uploaded. + allow_authenticated: whether media marked as authenticated may be served to this request federation: whether the local media being fetched is for a federation request Returns: @@ -451,6 +453,10 @@ class MediaRepository: if not media_info: return + if self.hs.config.media.enable_authenticated_media and not allow_authenticated: + if media_info.authenticated: + raise NotFoundError() + self.mark_recently_accessed(None, media_id) media_type = media_info.media_type @@ -481,6 +487,7 @@ class MediaRepository: max_timeout_ms: int, ip_address: str, use_federation_endpoint: bool, + allow_authenticated: bool = True, ) -> None: """Respond to requests for remote media. @@ -495,6 +502,8 @@ class MediaRepository: ip_address: the IP address of the requester use_federation_endpoint: whether to request the remote media over the new federation `/download` endpoint + allow_authenticated: whether media marked as authenticated may be served to this + request Returns: Resolves once a response has successfully been written to request @@ -526,6 +535,7 @@ class MediaRepository: self.download_ratelimiter, ip_address, use_federation_endpoint, + allow_authenticated, ) # We deliberately stream the file outside the lock @@ -548,6 +558,7 @@ class MediaRepository: max_timeout_ms: int, ip_address: str, use_federation: bool, + allow_authenticated: bool, ) -> RemoteMedia: """Gets the media info associated with the remote file, downloading if necessary. @@ -560,6 +571,8 @@ class MediaRepository: ip_address: IP address of the requester use_federation: if a download is necessary, whether to request the remote file over the federation `/download` endpoint + allow_authenticated: whether media marked as authenticated may be served to this + request Returns: The media info of the file @@ -581,6 +594,7 @@ class MediaRepository: self.download_ratelimiter, ip_address, use_federation, + allow_authenticated, ) # Ensure we actually use the responder so that it releases resources @@ -598,6 +612,7 @@ class MediaRepository: download_ratelimiter: Ratelimiter, ip_address: str, use_federation_endpoint: bool, + allow_authenticated: bool, ) -> Tuple[Optional[Responder], RemoteMedia]: """Looks for media in local cache, if not there then attempt to download from remote server. @@ -619,6 +634,11 @@ class MediaRepository: """ media_info = await self.store.get_cached_remote_media(server_name, media_id) + if self.hs.config.media.enable_authenticated_media and not allow_authenticated: + # if it isn't cached then don't fetch it or if it's authenticated then don't serve it + if not media_info or media_info.authenticated: + raise NotFoundError() + # file_id is the ID we use to track the file locally. If we've already # seen the file then reuse the existing ID, otherwise generate a new # one. @@ -792,6 +812,11 @@ class MediaRepository: logger.info("Stored remote media in file %r", fname) + if self.hs.config.media.enable_authenticated_media: + authenticated = True + else: + authenticated = False + return RemoteMedia( media_origin=server_name, media_id=media_id, @@ -802,6 +827,7 @@ class MediaRepository: filesystem_id=file_id, last_access_ts=time_now_ms, quarantined_by=None, + authenticated=authenticated, ) async def _federation_download_remote_file( @@ -915,6 +941,11 @@ class MediaRepository: logger.debug("Stored remote media in file %r", fname) + if self.hs.config.media.enable_authenticated_media: + authenticated = True + else: + authenticated = False + return RemoteMedia( media_origin=server_name, media_id=media_id, @@ -925,6 +956,7 @@ class MediaRepository: filesystem_id=file_id, last_access_ts=time_now_ms, quarantined_by=None, + authenticated=authenticated, ) def _get_thumbnail_requirements( @@ -1030,7 +1062,12 @@ class MediaRepository: t_len = os.path.getsize(output_path) await self.store.store_local_thumbnail( - media_id, t_width, t_height, t_type, t_method, t_len + media_id, + t_width, + t_height, + t_type, + t_method, + t_len, ) return output_path diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py index 413a720e40..ef6aa8ccf5 100644 --- a/synapse/media/thumbnailer.py +++ b/synapse/media/thumbnailer.py @@ -26,7 +26,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Type from PIL import Image -from synapse.api.errors import Codes, SynapseError, cs_error +from synapse.api.errors import Codes, NotFoundError, SynapseError, cs_error from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP from synapse.http.server import respond_with_json from synapse.http.site import SynapseRequest @@ -274,6 +274,7 @@ class ThumbnailProvider: m_type: str, max_timeout_ms: int, for_federation: bool, + allow_authenticated: bool = True, ) -> None: media_info = await self.media_repo.get_local_media_info( request, media_id, max_timeout_ms @@ -281,6 +282,12 @@ class ThumbnailProvider: if not media_info: return + # if the media the thumbnail is generated from is authenticated, don't serve the + # thumbnail over an unauthenticated endpoint + if self.hs.config.media.enable_authenticated_media and not allow_authenticated: + if media_info.authenticated: + raise NotFoundError() + thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) await self._select_and_respond_with_thumbnail( request, @@ -307,14 +314,20 @@ class ThumbnailProvider: desired_type: str, max_timeout_ms: int, for_federation: bool, + allow_authenticated: bool = True, ) -> None: media_info = await self.media_repo.get_local_media_info( request, media_id, max_timeout_ms ) - if not media_info: return + # if the media the thumbnail is generated from is authenticated, don't serve the + # thumbnail over an unauthenticated endpoint + if self.hs.config.media.enable_authenticated_media and not allow_authenticated: + if media_info.authenticated: + raise NotFoundError() + thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) for info in thumbnail_infos: t_w = info.width == desired_width @@ -381,14 +394,27 @@ class ThumbnailProvider: max_timeout_ms: int, ip_address: str, use_federation: bool, + allow_authenticated: bool = True, ) -> None: media_info = await self.media_repo.get_remote_media_info( - server_name, media_id, max_timeout_ms, ip_address, use_federation + server_name, + media_id, + max_timeout_ms, + ip_address, + use_federation, + allow_authenticated, ) if not media_info: respond_404(request) return + # if the media the thumbnail is generated from is authenticated, don't serve the + # thumbnail over an unauthenticated endpoint + if self.hs.config.media.enable_authenticated_media and not allow_authenticated: + if media_info.authenticated: + respond_404(request) + return + thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) @@ -446,16 +472,28 @@ class ThumbnailProvider: max_timeout_ms: int, ip_address: str, use_federation: bool, + allow_authenticated: bool = True, ) -> None: # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead of # downloading the remote file and generating our own thumbnails. media_info = await self.media_repo.get_remote_media_info( - server_name, media_id, max_timeout_ms, ip_address, use_federation + server_name, + media_id, + max_timeout_ms, + ip_address, + use_federation, + allow_authenticated, ) if not media_info: return + # if the media the thumbnail is generated from is authenticated, don't serve the + # thumbnail over an unauthenticated endpoint + if self.hs.config.media.enable_authenticated_media and not allow_authenticated: + if media_info.authenticated: + raise NotFoundError() + thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) @@ -485,8 +523,8 @@ class ThumbnailProvider: file_id: str, url_cache: bool, for_federation: bool, - server_name: Optional[str] = None, media_info: Optional[LocalMedia] = None, + server_name: Optional[str] = None, ) -> None: """ Respond to a request with an appropriate thumbnail from the previously generated thumbnails. diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py index c32c626905..3c3f703667 100644 --- a/synapse/rest/media/download_resource.py +++ b/synapse/rest/media/download_resource.py @@ -84,7 +84,7 @@ class DownloadResource(RestServlet): if self._is_mine_server_name(server_name): await self.media_repo.get_local_media( - request, media_id, file_name, max_timeout_ms + request, media_id, file_name, max_timeout_ms, allow_authenticated=False ) else: allow_remote = parse_boolean(request, "allow_remote", default=True) @@ -106,4 +106,5 @@ class DownloadResource(RestServlet): max_timeout_ms, ip_address, False, + allow_authenticated=False, ) diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py index 70354aa439..536fea4c32 100644 --- a/synapse/rest/media/thumbnail_resource.py +++ b/synapse/rest/media/thumbnail_resource.py @@ -96,6 +96,7 @@ class ThumbnailResource(RestServlet): m_type, max_timeout_ms, False, + allow_authenticated=False, ) else: await self.thumbnail_provider.respond_local_thumbnail( @@ -107,6 +108,7 @@ class ThumbnailResource(RestServlet): m_type, max_timeout_ms, False, + allow_authenticated=False, ) self.media_repo.mark_recently_accessed(None, media_id) else: @@ -134,6 +136,7 @@ class ThumbnailResource(RestServlet): m_type, max_timeout_ms, ip_address, - False, + use_federation=False, + allow_authenticated=False, ) self.media_repo.mark_recently_accessed(server_name, media_id) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 6128332af8..7617fd3ad4 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -64,6 +64,7 @@ class LocalMedia: quarantined_by: Optional[str] safe_from_quarantine: bool user_id: Optional[str] + authenticated: Optional[bool] @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -77,6 +78,7 @@ class RemoteMedia: created_ts: int last_access_ts: int quarantined_by: Optional[str] + authenticated: Optional[bool] @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -218,6 +220,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "last_access_ts", "safe_from_quarantine", "user_id", + "authenticated", ), allow_none=True, desc="get_local_media", @@ -235,6 +238,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): last_access_ts=row[6], safe_from_quarantine=row[7], user_id=row[8], + authenticated=row[9], ) async def get_local_media_by_user_paginate( @@ -290,7 +294,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): last_access_ts, quarantined_by, safe_from_quarantine, - user_id + user_id, + authenticated FROM local_media_repository WHERE user_id = ? ORDER BY {order_by_column} {order}, media_id ASC @@ -314,6 +319,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): quarantined_by=row[7], safe_from_quarantine=bool(row[8]), user_id=row[9], + authenticated=row[10], ) for row in txn ] @@ -417,12 +423,18 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): time_now_ms: int, user_id: UserID, ) -> None: + if self.hs.config.media.enable_authenticated_media: + authenticated = True + else: + authenticated = False + await self.db_pool.simple_insert( "local_media_repository", { "media_id": media_id, "created_ts": time_now_ms, "user_id": user_id.to_string(), + "authenticated": authenticated, }, desc="store_local_media_id", ) @@ -438,6 +450,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): user_id: UserID, url_cache: Optional[str] = None, ) -> None: + if self.hs.config.media.enable_authenticated_media: + authenticated = True + else: + authenticated = False + await self.db_pool.simple_insert( "local_media_repository", { @@ -448,6 +465,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "media_length": media_length, "user_id": user_id.to_string(), "url_cache": url_cache, + "authenticated": authenticated, }, desc="store_local_media", ) @@ -638,6 +656,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "filesystem_id", "last_access_ts", "quarantined_by", + "authenticated", ), allow_none=True, desc="get_cached_remote_media", @@ -654,6 +673,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): filesystem_id=row[4], last_access_ts=row[5], quarantined_by=row[6], + authenticated=row[7], ) async def store_cached_remote_media( @@ -666,6 +686,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): upload_name: Optional[str], filesystem_id: str, ) -> None: + if self.hs.config.media.enable_authenticated_media: + authenticated = True + else: + authenticated = False + await self.db_pool.simple_insert( "remote_media_cache", { @@ -677,6 +702,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "upload_name": upload_name, "filesystem_id": filesystem_id, "last_access_ts": time_now_ms, + "authenticated": authenticated, }, desc="store_cached_remote_media", ) diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 0dc5d24249..581d00346b 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -19,7 +19,7 @@ # # -SCHEMA_VERSION = 85 # remember to update the list below when updating +SCHEMA_VERSION = 86 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -139,6 +139,9 @@ Changes in SCHEMA_VERSION = 84 Changes in SCHEMA_VERSION = 85 - Add a column `suspended` to the `users` table + +Changes in SCHEMA_VERSION = 86 + - Add a column `authenticated` to the tables `local_media_repository` and `remote_media_cache` """ diff --git a/synapse/storage/schema/main/delta/86/01_authenticate_media.sql b/synapse/storage/schema/main/delta/86/01_authenticate_media.sql new file mode 100644 index 0000000000..c1ac01ae95 --- /dev/null +++ b/synapse/storage/schema/main/delta/86/01_authenticate_media.sql @@ -0,0 +1,15 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +ALTER TABLE remote_media_cache ADD COLUMN authenticated BOOLEAN DEFAULT FALSE NOT NULL; +ALTER TABLE local_media_repository ADD COLUMN authenticated BOOLEAN DEFAULT FALSE NOT NULL; diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index 466c5a0b70..30b6d31d0a 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -43,6 +43,7 @@ from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor from twisted.web.http_headers import Headers from twisted.web.iweb import UNKNOWN_LENGTH, IResponse +from twisted.web.resource import Resource from synapse.api.errors import HttpResponseException from synapse.api.ratelimiting import Ratelimiter @@ -2466,3 +2467,211 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase): server_name=None, ) ) + + +configs = [ + {"extra_config": {"dynamic_thumbnails": True}}, + {"extra_config": {"dynamic_thumbnails": False}}, +] + + +@parameterized_class(configs) +class AuthenticatedMediaTestCase(unittest.HomeserverTestCase): + extra_config: Dict[str, Any] + servlets = [ + media.register_servlets, + login.register_servlets, + admin.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + + self.clock = clock + self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() + os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) + config["media_store_path"] = self.media_store_path + config["enable_authenticated_media"] = True + + provider_config = { + "module": "synapse.media.storage_provider.FileStorageProviderBackend", + "store_local": True, + "store_synchronous": False, + "store_remote": True, + "config": {"directory": self.storage_path}, + } + + config["media_storage_providers"] = [provider_config] + config.update(self.extra_config) + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.repo = hs.get_media_repository() + self.client = hs.get_federation_http_client() + self.store = hs.get_datastores().main + self.user = self.register_user("user", "pass") + self.tok = self.login("user", "pass") + + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + + def test_authenticated_media(self) -> None: + # upload some local media with authentication on + channel = self.make_request( + "POST", + "_matrix/media/v3/upload?filename=test_png_upload", + SMALL_PNG, + self.tok, + shorthand=False, + content_type=b"image/png", + custom_headers=[("Content-Length", str(67))], + ) + self.assertEqual(channel.code, 200) + res = channel.json_body.get("content_uri") + assert res is not None + uri = res.split("mxc://")[1] + + # request media over authenticated endpoint, should be found + channel2 = self.make_request( + "GET", + f"_matrix/client/v1/media/download/{uri}", + access_token=self.tok, + shorthand=False, + ) + self.assertEqual(channel2.code, 200) + + # request same media over unauthenticated media, should raise 404 not found + channel3 = self.make_request( + "GET", f"_matrix/media/v3/download/{uri}", shorthand=False + ) + self.assertEqual(channel3.code, 404) + + # check thumbnails as well + params = "?width=32&height=32&method=crop" + channel4 = self.make_request( + "GET", + f"/_matrix/client/v1/media/thumbnail/{uri}{params}", + shorthand=False, + access_token=self.tok, + ) + self.assertEqual(channel4.code, 200) + + params = "?width=32&height=32&method=crop" + channel5 = self.make_request( + "GET", + f"/_matrix/media/r0/thumbnail/{uri}{params}", + shorthand=False, + access_token=self.tok, + ) + self.assertEqual(channel5.code, 404) + + # Inject a piece of remote media. + file_id = "abcdefg12345" + file_info = FileInfo(server_name="lonelyIsland", file_id=file_id) + + media_storage = self.hs.get_media_repository().media_storage + + ctx = media_storage.store_into_file(file_info) + (f, fname) = self.get_success(ctx.__aenter__()) + f.write(SMALL_PNG) + self.get_success(ctx.__aexit__(None, None, None)) + + # we write the authenticated status when storing media, so this should pick up + # config and authenticate the media + self.get_success( + self.store.store_cached_remote_media( + origin="lonelyIsland", + media_id="52", + media_type="image/png", + media_length=1, + time_now_ms=self.clock.time_msec(), + upload_name="remote_test.png", + filesystem_id=file_id, + ) + ) + + # ensure we have thumbnails for the non-dynamic code path + if self.extra_config == {"dynamic_thumbnails": False}: + self.get_success( + self.repo._generate_thumbnails( + "lonelyIsland", "52", file_id, "image/png" + ) + ) + + channel6 = self.make_request( + "GET", + "_matrix/client/v1/media/download/lonelyIsland/52", + access_token=self.tok, + shorthand=False, + ) + self.assertEqual(channel6.code, 200) + + channel7 = self.make_request( + "GET", f"_matrix/media/v3/download/{uri}", shorthand=False + ) + self.assertEqual(channel7.code, 404) + + params = "?width=32&height=32&method=crop" + channel8 = self.make_request( + "GET", + f"/_matrix/client/v1/media/thumbnail/lonelyIsland/52{params}", + shorthand=False, + access_token=self.tok, + ) + self.assertEqual(channel8.code, 200) + + channel9 = self.make_request( + "GET", + f"/_matrix/media/r0/thumbnail/lonelyIsland/52{params}", + shorthand=False, + access_token=self.tok, + ) + self.assertEqual(channel9.code, 404) + + # Inject a piece of local media that isn't authenticated + file_id = "abcdefg123456" + file_info = FileInfo(None, file_id=file_id) + + ctx = media_storage.store_into_file(file_info) + (f, fname) = self.get_success(ctx.__aenter__()) + f.write(SMALL_PNG) + self.get_success(ctx.__aexit__(None, None, None)) + + self.get_success( + self.store.db_pool.simple_insert( + "local_media_repository", + { + "media_id": "abcdefg123456", + "media_type": "image/png", + "created_ts": self.clock.time_msec(), + "upload_name": "test_local", + "media_length": 1, + "user_id": "someone", + "url_cache": None, + "authenticated": False, + }, + desc="store_local_media", + ) + ) + + # check that unauthenticated media is still available over both endpoints + channel9 = self.make_request( + "GET", + "/_matrix/client/v1/media/download/test/abcdefg123456", + shorthand=False, + access_token=self.tok, + ) + self.assertEqual(channel9.code, 200) + + channel10 = self.make_request( + "GET", + "/_matrix/media/r0/download/test/abcdefg123456", + shorthand=False, + access_token=self.tok, + ) + self.assertEqual(channel10.code, 200) -- cgit 1.5.1 From d225b6b3ebea419bdf0e6c0f1476544053f2dcbf Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 23 Jul 2024 14:03:14 +0100 Subject: Speed up SS room sorting (#17468) We do this by bulk fetching the latest stream ordering. --------- Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/17468.misc | 1 + synapse/handlers/sliding_sync.py | 43 ++++--- synapse/storage/databases/main/event_federation.py | 5 + synapse/storage/databases/main/stream.py | 123 ++++++++++++++++++++- synapse/util/caches/stream_change_cache.py | 12 +- tests/util/test_stream_change_cache.py | 4 +- 6 files changed, 159 insertions(+), 29 deletions(-) create mode 100644 changelog.d/17468.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/17468.misc b/changelog.d/17468.misc new file mode 100644 index 0000000000..d908776204 --- /dev/null +++ b/changelog.d/17468.misc @@ -0,0 +1 @@ +Speed up sorting of the room list in sliding sync. diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 886d7c7159..554ab59bf3 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -1230,34 +1230,33 @@ class SlidingSyncHandler: # Assemble a map of room ID to the `stream_ordering` of the last activity that the # user should see in the room (<= `to_token`) last_activity_in_room_map: Dict[str, int] = {} - for room_id, room_for_user in sync_room_map.items(): - # If they are fully-joined to the room, let's find the latest activity - # at/before the `to_token`. - if room_for_user.membership == Membership.JOIN: - last_event_result = ( - await self.store.get_last_event_pos_in_room_before_stream_ordering( - room_id, to_token.room_key - ) - ) - - # If the room has no events at/before the `to_token`, this is probably a - # mistake in the code that generates the `sync_room_map` since that should - # only give us rooms that the user had membership in during the token range. - assert last_event_result is not None - _, event_pos = last_event_result - - last_activity_in_room_map[room_id] = event_pos.stream - else: - # Otherwise, if the user has left/been invited/knocked/been banned from - # a room, they shouldn't see anything past that point. + for room_id, room_for_user in sync_room_map.items(): + if room_for_user.membership != Membership.JOIN: + # If the user has left/been invited/knocked/been banned from a + # room, they shouldn't see anything past that point. # - # FIXME: It's possible that people should see beyond this point in - # invited/knocked cases if for example the room has + # FIXME: It's possible that people should see beyond this point + # in invited/knocked cases if for example the room has # `invite`/`world_readable` history visibility, see # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932 last_activity_in_room_map[room_id] = room_for_user.event_pos.stream + # For fully-joined rooms, we find the latest activity at/before the + # `to_token`. + joined_room_positions = ( + await self.store.bulk_get_last_event_pos_in_room_before_stream_ordering( + [ + room_id + for room_id, room_for_user in sync_room_map.items() + if room_for_user.membership == Membership.JOIN + ], + to_token.room_key, + ) + ) + + last_activity_in_room_map.update(joined_room_positions) + return sorted( sync_room_map.values(), # Sort by the last activity (stream_ordering) in the room diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 24abab4a23..715846865b 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1313,6 +1313,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # We want to make the cache more effective, so we clamp to the last # change before the given ordering. last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined] + if last_change is None: + # If the room isn't in the cache we know that the last change was + # somewhere before the earliest known position of the cache, so we + # can clamp to that. + last_change = self._events_stream_cache.get_earliest_known_position() # type: ignore[attr-defined] # We don't always have a full stream_to_exterm_id table, e.g. after # the upgrade that introduced it, so we make sure we never ask for a diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index e74e0d2e91..b034361aec 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -78,10 +78,11 @@ from synapse.storage.database import ( from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.util.id_generators import MultiWriterIdGenerator -from synapse.types import PersistedEventPosition, RoomStreamToken +from synapse.types import PersistedEventPosition, RoomStreamToken, StrCollection from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.cancellation import cancellable +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -1293,6 +1294,126 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): get_last_event_pos_in_room_before_stream_ordering_txn, ) + async def bulk_get_last_event_pos_in_room_before_stream_ordering( + self, + room_ids: StrCollection, + end_token: RoomStreamToken, + ) -> Dict[str, int]: + """Bulk fetch the stream position of the latest events in the given + rooms + """ + + min_token = end_token.stream + max_token = end_token.get_max_stream_pos() + results: Dict[str, int] = {} + + # First, we check for the rooms in the stream change cache to see if we + # can just use the latest position from it. + missing_room_ids: Set[str] = set() + for room_id in room_ids: + stream_pos = self._events_stream_cache.get_max_pos_of_last_change(room_id) + if stream_pos and stream_pos <= min_token: + results[room_id] = stream_pos + else: + missing_room_ids.add(room_id) + + # Next, we query the stream position from the DB. At first we fetch all + # positions less than the *max* stream pos in the token, then filter + # them down. We do this as a) this is a cheaper query, and b) the vast + # majority of rooms will have a latest token from before the min stream + # pos. + + def bulk_get_last_event_pos_txn( + txn: LoggingTransaction, batch_room_ids: StrCollection + ) -> Dict[str, int]: + # This query fetches the latest stream position in the rooms before + # the given max position. + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", batch_room_ids + ) + sql = f""" + SELECT room_id, ( + SELECT stream_ordering FROM events AS e + LEFT JOIN rejections USING (event_id) + WHERE e.room_id = r.room_id + AND stream_ordering <= ? + AND NOT outlier + AND rejection_reason IS NULL + ORDER BY stream_ordering DESC + LIMIT 1 + ) + FROM rooms AS r + WHERE {clause} + """ + txn.execute(sql, [max_token] + args) + return {row[0]: row[1] for row in txn} + + recheck_rooms: Set[str] = set() + for batched in batch_iter(missing_room_ids, 1000): + result = await self.db_pool.runInteraction( + "bulk_get_last_event_pos_in_room_before_stream_ordering", + bulk_get_last_event_pos_txn, + batched, + ) + + # Check that the stream position for the rooms are from before the + # minimum position of the token. If not then we need to fetch more + # rows. + for room_id, stream in result.items(): + if stream <= min_token: + results[room_id] = stream + else: + recheck_rooms.add(room_id) + + if not recheck_rooms: + return results + + # For the remaining rooms we need to fetch all rows between the min and + # max stream positions in the end token, and filter out the rows that + # are after the end token. + # + # This query should be fast as the range between the min and max should + # be small. + + def bulk_get_last_event_pos_recheck_txn( + txn: LoggingTransaction, batch_room_ids: StrCollection + ) -> Dict[str, int]: + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", batch_room_ids + ) + sql = f""" + SELECT room_id, instance_name, stream_ordering + FROM events + WHERE ? < stream_ordering AND stream_ordering <= ? + AND NOT outlier + AND rejection_reason IS NULL + AND {clause} + ORDER BY stream_ordering ASC + """ + txn.execute(sql, [min_token, max_token] + args) + + # We take the max stream ordering that is less than the token. Since + # we ordered by stream ordering we just need to iterate through and + # take the last matching stream ordering. + txn_results: Dict[str, int] = {} + for row in txn: + room_id = row[0] + event_pos = PersistedEventPosition(row[1], row[2]) + if not event_pos.persisted_after(end_token): + txn_results[room_id] = event_pos.stream + + return txn_results + + for batched in batch_iter(recheck_rooms, 1000): + recheck_result = await self.db_pool.runInteraction( + "bulk_get_last_event_pos_in_room_before_stream_ordering_recheck", + bulk_get_last_event_pos_recheck_txn, + batched, + ) + results.update(recheck_result) + + return results + async def get_current_room_stream_token_for_room_id( self, room_id: str ) -> RoomStreamToken: diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 91c335f85b..16fcb00206 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -327,7 +327,7 @@ class StreamChangeCache: for entity in r: self._entity_to_key.pop(entity, None) - def get_max_pos_of_last_change(self, entity: EntityType) -> int: + def get_max_pos_of_last_change(self, entity: EntityType) -> Optional[int]: """Returns an upper bound of the stream id of the last change to an entity. @@ -335,7 +335,11 @@ class StreamChangeCache: entity: The entity to check. Return: - The stream position of the latest change for the given entity or - the earliest known stream position if the entitiy is unknown. + The stream position of the latest change for the given entity, if + known """ - return self._entity_to_key.get(entity, self._earliest_known_stream_pos) + return self._entity_to_key.get(entity) + + def get_earliest_known_position(self) -> int: + """Returns the earliest position in the cache.""" + return self._earliest_known_stream_pos diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 5d38718a50..af1199ef8a 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -249,5 +249,5 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): self.assertEqual(cache.get_max_pos_of_last_change("bar@baz.net"), 3) self.assertEqual(cache.get_max_pos_of_last_change("user@elsewhere.org"), 4) - # Unknown entities will return the stream start position. - self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), 1) + # Unknown entities will return None + self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), None) -- cgit 1.5.1