Diffstat (limited to 'synapse')
-rwxr-xr-x  synapse/_scripts/synapse_port_db.py  3
-rw-r--r--  synapse/api/errors.py  2
-rw-r--r--  synapse/app/generic_worker.py  4
-rw-r--r--  synapse/config/ratelimiting.py  7
-rw-r--r--  synapse/config/repository.py  6
-rw-r--r--  synapse/federation/federation_server.py  2
-rw-r--r--  synapse/federation/sender/__init__.py  4
-rw-r--r--  synapse/handlers/admin.py  2
-rw-r--r--  synapse/handlers/device.py  10
-rw-r--r--  synapse/handlers/e2e_keys.py  20
-rw-r--r--  synapse/handlers/federation_event.py  20
-rw-r--r--  synapse/handlers/presence.py  2
-rw-r--r--  synapse/handlers/profile.py  15
-rw-r--r--  synapse/handlers/room.py  6
-rw-r--r--  synapse/handlers/room_list.py  43
-rw-r--r--  synapse/handlers/room_member.py  3
-rw-r--r--  synapse/handlers/room_summary.py  26
-rw-r--r--  synapse/handlers/sso.py  2
-rw-r--r--  synapse/handlers/sync.py  4
-rw-r--r--  synapse/handlers/user_directory.py  8
-rw-r--r--  synapse/http/matrixfederationclient.py  14
-rw-r--r--  synapse/logging/opentracing.py  7
-rw-r--r--  synapse/media/_base.py  6
-rw-r--r--  synapse/media/media_repository.py  284
-rw-r--r--  synapse/media/url_previewer.py  11
-rw-r--r--  synapse/metrics/_reactor_metrics.py  130
-rw-r--r--  synapse/module_api/__init__.py  3
-rw-r--r--  synapse/module_api/callbacks/third_party_event_rules_callbacks.py  3
-rw-r--r--  synapse/push/bulk_push_rule_evaluator.py  56
-rw-r--r--  synapse/replication/tcp/handler.py  72
-rw-r--r--  synapse/replication/tcp/redis.py  2
-rw-r--r--  synapse/replication/tcp/resource.py  17
-rw-r--r--  synapse/replication/tcp/streams/_base.py  18
-rw-r--r--  synapse/rest/admin/__init__.py  2
-rw-r--r--  synapse/rest/admin/media.py  6
-rw-r--r--  synapse/rest/admin/registration_tokens.py  13
-rw-r--r--  synapse/rest/admin/rooms.py  19
-rw-r--r--  synapse/rest/admin/users.py  50
-rw-r--r--  synapse/rest/client/account.py  19
-rw-r--r--  synapse/rest/client/directory.py  2
-rw-r--r--  synapse/rest/client/keys.py  16
-rw-r--r--  synapse/rest/media/create_resource.py  83
-rw-r--r--  synapse/rest/media/download_resource.py  22
-rw-r--r--  synapse/rest/media/media_repository_resource.py  8
-rw-r--r--  synapse/rest/media/thumbnail_resource.py  81
-rw-r--r--  synapse/rest/media/upload_resource.py  75
-rw-r--r--  synapse/storage/background_updates.py  32
-rw-r--r--  synapse/storage/controllers/persist_events.py  252
-rw-r--r--  synapse/storage/database.py  77
-rw-r--r--  synapse/storage/databases/__init__.py  2
-rw-r--r--  synapse/storage/databases/main/__init__.py  52
-rw-r--r--  synapse/storage/databases/main/account_data.py  24
-rw-r--r--  synapse/storage/databases/main/cache.py  75
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py  106
-rw-r--r--  synapse/storage/databases/main/devices.py  100
-rw-r--r--  synapse/storage/databases/main/e2e_room_keys.py  31
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py  114
-rw-r--r--  synapse/storage/databases/main/event_federation.py  24
-rw-r--r--  synapse/storage/databases/main/events.py  415
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py  19
-rw-r--r--  synapse/storage/databases/main/events_worker.py  8
-rw-r--r--  synapse/storage/databases/main/keys.py  17
-rw-r--r--  synapse/storage/databases/main/media_repository.py  249
-rw-r--r--  synapse/storage/databases/main/monthly_active_users.py  2
-rw-r--r--  synapse/storage/databases/main/presence.py  9
-rw-r--r--  synapse/storage/databases/main/profile.py  28
-rw-r--r--  synapse/storage/databases/main/purge_events.py  33
-rw-r--r--  synapse/storage/databases/main/push_rule.py  94
-rw-r--r--  synapse/storage/databases/main/receipts.py  4
-rw-r--r--  synapse/storage/databases/main/registration.py  149
-rw-r--r--  synapse/storage/databases/main/room.py  129
-rw-r--r--  synapse/storage/databases/main/roommember.py  19
-rw-r--r--  synapse/storage/databases/main/search.py  12
-rw-r--r--  synapse/storage/databases/main/stream.py  30
-rw-r--r--  synapse/storage/databases/main/task_scheduler.py  48
-rw-r--r--  synapse/storage/databases/main/transactions.py  20
-rw-r--r--  synapse/storage/databases/main/ui_auth.py  31
-rw-r--r--  synapse/storage/databases/main/user_directory.py  27
-rw-r--r--  synapse/storage/databases/state/bg_updates.py  4
-rw-r--r--  synapse/storage/engines/postgres.py  3
-rw-r--r--  synapse/storage/schema/__init__.py  3
-rw-r--r--  synapse/storage/schema/main/delta/54/delete_forward_extremities.sql  2
-rw-r--r--  synapse/storage/schema/main/delta/56/remove_tombstoned_rooms_from_directory.sql  3
-rw-r--r--  synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql  3
-rw-r--r--  synapse/storage/schema/main/delta/83/05_cross_signing_key_update_grant.sql  15
-rw-r--r--  synapse/storage/util/id_generators.py  56
-rw-r--r--  synapse/util/__init__.py  2
-rw-r--r--  synapse/util/async_helpers.py  14
-rw-r--r--  synapse/util/check_dependencies.py  3
-rw-r--r--  synapse/util/iterutils.py  51
-rw-r--r--  synapse/util/task_scheduler.py  4
91 files changed, 2449 insertions, 1154 deletions
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py

index ef8590db65..75fe0183f6 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py
@@ -348,8 +348,7 @@ class Porter: backward_chunk = 0 already_ported = 0 else: - forward_chunk = row["forward_rowid"] - backward_chunk = row["backward_rowid"] + forward_chunk, backward_chunk = row if total_to_port is None: already_ported, total_to_port = await self._get_total_count_to_port( diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index fdb2955be8..fbd8b16ec3 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py
@@ -83,6 +83,8 @@ class Codes(str, Enum): USER_DEACTIVATED = "M_USER_DEACTIVATED" # USER_LOCKED = "M_USER_LOCKED" USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED" + NOT_YET_UPLOADED = "M_NOT_YET_UPLOADED" + CANNOT_OVERWRITE_MEDIA = "M_CANNOT_OVERWRITE_MEDIA" # Part of MSC3848 # https://github.com/matrix-org/matrix-spec-proposals/pull/3848 diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f7c80eee21..bcfb7a7200 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py
@@ -104,8 +104,8 @@ logger = logging.getLogger("synapse.app.generic_worker") class GenericWorkerStore( - # FIXME(#3714): We need to add UserDirectoryStore as we write directly - # rather than going via the correct worker. + # FIXME(https://github.com/matrix-org/synapse/issues/3714): We need to add + # UserDirectoryStore as we write directly rather than going via the correct worker. UserDirectoryStore, StatsStore, UIAuthWorkerStore, diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 4efbaeac0d..b1fcaf71a3 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py
@@ -204,3 +204,10 @@ class RatelimitConfig(Config): "rc_third_party_invite", defaults={"per_second": 0.0025, "burst_count": 5}, ) + + # Ratelimit create media requests: + self.rc_media_create = RatelimitSettings.parse( + config, + "rc_media_create", + defaults={"per_second": 10, "burst_count": 50}, + ) diff --git a/synapse/config/repository.py b/synapse/config/repository.py
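The new `rc_media_create` option follows Synapse's usual `per_second`/`burst_count` ratelimit shape (here defaulting to 10 requests per second with a burst of 50). As a rough illustration of how that pair of numbers behaves, here is a minimal token-bucket sketch in plain Python; it is not Synapse's `Ratelimiter` class, just an assumption-laden stand-in with the same two parameters.

```python
import time


class SimpleRateLimiter:
    """Hypothetical token-bucket limiter mirroring the per_second/burst_count
    shape used by rc_media_create (defaults: 10 per second, burst of 50)."""

    def __init__(self, per_second: float = 10, burst_count: float = 50):
        self.rate = per_second    # tokens replenished per second
        self.burst = burst_count  # maximum bucket size
        self.tokens = burst_count  # start with a full bucket
        self.last = time.monotonic()

    def allow(self) -> bool:
        now = time.monotonic()
        # Refill in proportion to elapsed time, capped at the burst size.
        self.tokens = min(self.burst, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False


limiter = SimpleRateLimiter()
print(all(limiter.allow() for _ in range(50)))  # True: the burst is allowed
print(limiter.allow())                          # False: the 51st immediate call is throttled
```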
index f6cfdd3e04..839c026d70 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py
@@ -141,6 +141,12 @@ class ContentRepositoryConfig(Config): "prevent_media_downloads_from", [] ) + self.unused_expiration_time = self.parse_duration( + config.get("unused_expiration_time", "24h") + ) + + self.max_pending_media_uploads = config.get("max_pending_media_uploads", 5) + self.media_store_path = self.ensure_directory( config.get("media_store_path", "media_store") ) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 8e3064c7e7..2bb2c64ebe 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py
@@ -84,7 +84,7 @@ from synapse.replication.http.federation import ( from synapse.storage.databases.main.lock import Lock from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.roommember import MemberSummary -from synapse.types import JsonDict, StateMap, get_domain_from_id, UserID +from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id from synapse.util import unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results from synapse.util.caches.response_cache import ResponseCache diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 7980d1a322..948fde6658 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py
@@ -581,14 +581,14 @@ class FederationSender(AbstractFederationSender): "get_joined_hosts", str(sg) ) if destinations is None: - # Add logging to help track down #13444 + # Add logging to help track down https://github.com/matrix-org/synapse/issues/13444 logger.info( "Unexpectedly did not have cached destinations for %s / %s", sg, event.event_id, ) else: - # Add logging to help track down #13444 + # Add logging to help track down https://github.com/matrix-org/synapse/issues/13444 logger.info( "Unexpectedly did not have cached prev group for %s", event.event_id, diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 2c2baeac67..d06f8e3296 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py
@@ -283,7 +283,7 @@ class AdminHandler: start, limit, user_id ) for media in media_ids: - writer.write_media_id(media["media_id"], media) + writer.write_media_id(media.media_id, attr.asdict(media)) logger.info( "[%s] Written %d media_ids of %s", diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 93472d0117..98e6e42563 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py
@@ -383,7 +383,7 @@ class DeviceWorkerHandler: ) DEVICE_MSGS_DELETE_BATCH_LIMIT = 1000 - DEVICE_MSGS_DELETE_SLEEP_MS = 1000 + DEVICE_MSGS_DELETE_SLEEP_MS = 100 async def _delete_device_messages( self, @@ -396,15 +396,17 @@ class DeviceWorkerHandler: up_to_stream_id = task.params["up_to_stream_id"] # Delete the messages in batches to avoid too much DB load. + from_stream_id = None while True: - res = await self.store.delete_messages_for_device( + from_stream_id, _ = await self.store.delete_messages_for_device_between( user_id=user_id, device_id=device_id, - up_to_stream_id=up_to_stream_id, + from_stream_id=from_stream_id, + to_stream_id=up_to_stream_id, limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT, ) - if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT: + if from_stream_id is None: return TaskStatus.COMPLETE, None, None await self.clock.sleep(DeviceHandler.DEVICE_MSGS_DELETE_SLEEP_MS / 1000.0) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
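The device-message deletion above switches from a count-based loop to range-based pagination: each batch returns the stream ID to resume from, and the loop sleeps briefly between batches. The sketch below shows the same pattern against a hypothetical `device_inbox` SQLite table with unique stream IDs per device; the table name, columns, and `delete_batch` helper are illustrative, not Synapse's actual storage API.

```python
import asyncio
import sqlite3
from typing import Optional, Tuple

BATCH_LIMIT = 1000
SLEEP_SECONDS = 0.1  # mirrors DEVICE_MSGS_DELETE_SLEEP_MS = 100


def delete_batch(
    conn: sqlite3.Connection,
    device_id: str,
    from_stream_id: Optional[int],
    to_stream_id: int,
    limit: int,
) -> Tuple[Optional[int], int]:
    """Delete up to `limit` rows, returning (resume_stream_id, deleted_count).

    resume_stream_id is None once everything up to to_stream_id is gone.
    Assumes stream_id is unique per device.
    """
    rows = conn.execute(
        "SELECT stream_id FROM device_inbox"
        " WHERE device_id = ? AND stream_id > COALESCE(?, -1) AND stream_id <= ?"
        " ORDER BY stream_id LIMIT ?",
        (device_id, from_stream_id, to_stream_id, limit),
    ).fetchall()
    if not rows:
        return None, 0
    last_id = rows[-1][0]
    conn.execute(
        "DELETE FROM device_inbox"
        " WHERE device_id = ? AND stream_id > COALESCE(?, -1) AND stream_id <= ?",
        (device_id, from_stream_id, last_id),
    )
    # Fewer rows than the limit means nothing is left in the range.
    return (None if len(rows) < limit else last_id), len(rows)


async def delete_all(conn: sqlite3.Connection, device_id: str, up_to: int) -> None:
    from_stream_id: Optional[int] = None
    while True:
        from_stream_id, _ = delete_batch(
            conn, device_id, from_stream_id, up_to, BATCH_LIMIT
        )
        if from_stream_id is None:
            return
        await asyncio.sleep(SLEEP_SECONDS)  # let other DB work proceed between batches


if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE device_inbox (device_id TEXT, stream_id INTEGER)")
    conn.executemany(
        "INSERT INTO device_inbox VALUES ('DEV', ?)", [(i,) for i in range(2500)]
    )
    asyncio.run(delete_all(conn, "DEV", up_to=2000))
    print(conn.execute("SELECT COUNT(*) FROM device_inbox").fetchone()[0])  # 499
```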
index d06524495f..70fa931d17 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py
@@ -1450,19 +1450,25 @@ class E2eKeysHandler: return desired_key_data - async def is_cross_signing_set_up_for_user(self, user_id: str) -> bool: + async def check_cross_signing_setup(self, user_id: str) -> Tuple[bool, bool]: """Checks if the user has cross-signing set up Args: user_id: The user to check - Returns: - True if the user has cross-signing set up, False otherwise + Returns: a 2-tuple of booleans + - whether the user has cross-signing set up, and + - whether the user's master cross-signing key may be replaced without UIA. """ - existing_master_key = await self.store.get_e2e_cross_signing_key( - user_id, "master" - ) - return existing_master_key is not None + ( + exists, + ts_replacable_without_uia_before, + ) = await self.store.get_master_cross_signing_key_updatable_before(user_id) + + if ts_replacable_without_uia_before is None: + return exists, False + else: + return exists, self.clock.time_msec() < ts_replacable_without_uia_before def _check_cross_signing_key( diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 0cc8e990d9..f4c17894aa 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py
@@ -88,7 +88,7 @@ from synapse.types import ( ) from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer, concurrently_execute -from synapse.util.iterutils import batch_iter, partition +from synapse.util.iterutils import batch_iter, partition, sorted_topologically_batched from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import shortstr @@ -748,7 +748,7 @@ class FederationEventHandler: # fetching fresh state for the room if the missing event # can't be found, which slightly reduces our security. # it may also increase our DAG extremity count for the room, - # causing additional state resolution? See #1760. + # causing additional state resolution? See https://github.com/matrix-org/synapse/issues/1760. # However, fetching state doesn't hold the linearizer lock # apparently. # @@ -1669,14 +1669,13 @@ class FederationEventHandler: # XXX: it might be possible to kick this process off in parallel with fetching # the events. - while event_map: - # build a list of events whose auth events are not in the queue. - roots = tuple( - ev - for ev in event_map.values() - if not any(aid in event_map for aid in ev.auth_event_ids()) - ) + # We need to persist an event's auth events before the event. + auth_graph = { + ev: [event_map[e_id] for e_id in ev.auth_event_ids() if e_id in event_map] + for ev in event_map.values() + } + for roots in sorted_topologically_batched(event_map.values(), auth_graph): if not roots: # if *none* of the remaining events are ready, that means # we have a loop. This either means a bug in our logic, or that @@ -1698,9 +1697,6 @@ class FederationEventHandler: await self._auth_and_persist_outliers_inner(room_id, roots) - for ev in roots: - del event_map[ev.event_id] - async def _auth_and_persist_outliers_inner( self, room_id: str, fetched_events: Collection[EventBase] ) -> None: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
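`sorted_topologically_batched` (from `synapse.util.iterutils`, extended in this change set) lets the handler persist each "generation" of the auth-event DAG together instead of recomputing the ready set on every pass. Below is a minimal sketch of that kind of batched topological sort; it follows Kahn's algorithm in spirit but is not the actual Synapse implementation, and the cycle handling here simply stops yielding.

```python
from typing import Dict, Hashable, Iterable, List, Set


def topological_batches(
    nodes: Iterable[Hashable], graph: Dict[Hashable, List[Hashable]]
) -> Iterable[Set[Hashable]]:
    """Yield batches of nodes such that every node's dependencies (graph[node])
    appear in an earlier batch. Nodes involved in a cycle are never yielded."""
    pending = set(nodes)
    while pending:
        # A node is ready once none of its dependencies are still pending.
        batch = {
            n
            for n in pending
            if all(dep not in pending for dep in graph.get(n, ()))
        }
        if not batch:
            # Remaining nodes form a cycle; callers decide how to handle this.
            return
        yield batch
        pending -= batch


# Example: B and C depend on A; D depends on B and C.
deps = {"A": [], "B": ["A"], "C": ["A"], "D": ["B", "C"]}
print(list(topological_batches(deps, deps)))
# e.g. [{'A'}, {'B', 'C'}, {'D'}] (set ordering within a batch may vary)
```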
index 202beee738..4137fd50b1 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py
@@ -1816,7 +1816,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): # the same token repeatedly. # # Hence this guard where we just return nothing so that the sync - # doesn't return. C.f. #5503. + # doesn't return. C.f. https://github.com/matrix-org/synapse/issues/5503. return [], max_token # Figure out which other users this user should explicitly receive diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index c2109036ec..1027fbfd28 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py
@@ -13,7 +13,7 @@ # limitations under the License. import logging import random -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union from synapse.api.errors import ( AuthError, @@ -23,6 +23,7 @@ from synapse.api.errors import ( StoreError, SynapseError, ) +from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util.caches.descriptors import cached from synapse.util.stringutils import parse_and_validate_mxc_uri @@ -306,7 +307,9 @@ class ProfileHandler: server_name = host if self._is_mine_server_name(server_name): - media_info = await self.store.get_local_media(media_id) + media_info: Optional[ + Union[LocalMedia, RemoteMedia] + ] = await self.store.get_local_media(media_id) else: media_info = await self.store.get_cached_remote_media(server_name, media_id) @@ -322,12 +325,12 @@ class ProfileHandler: if self.max_avatar_size: # Ensure avatar does not exceed max allowed avatar size - if media_info["media_length"] > self.max_avatar_size: + if media_info.media_length > self.max_avatar_size: logger.warning( "Forbidding avatar change to %s: %d bytes is above the allowed size " "limit", mxc, - media_info["media_length"], + media_info.media_length, ) return False @@ -335,12 +338,12 @@ class ProfileHandler: # Ensure the avatar's file type is allowed if ( self.allowed_avatar_mimetypes - and media_info["media_type"] not in self.allowed_avatar_mimetypes + and media_info.media_type not in self.allowed_avatar_mimetypes ): logger.warning( "Forbidding avatar change to %s: mimetype %s not allowed", mxc, - media_info["media_type"], + media_info.media_type, ) return False diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 6d680b0795..afd8138caf 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
@@ -269,7 +269,7 @@ class RoomCreationHandler: self, requester: Requester, old_room_id: str, - old_room: Dict[str, Any], + old_room: Tuple[bool, str, bool], new_room_id: str, new_version: RoomVersion, tombstone_event: EventBase, @@ -279,7 +279,7 @@ class RoomCreationHandler: Args: requester: the user requesting the upgrade old_room_id: the id of the room to be replaced - old_room: a dict containing room information for the room to be replaced, + old_room: a tuple containing room information for the room to be replaced, as returned by `RoomWorkerStore.get_room`. new_room_id: the id of the replacement room new_version: the version to upgrade the room to @@ -299,7 +299,7 @@ class RoomCreationHandler: await self.store.store_room( room_id=new_room_id, room_creator_user_id=user_id, - is_public=old_room["is_public"], + is_public=old_room[0], room_version=new_version, ) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 36e2db8975..2947e154be 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py
@@ -33,6 +33,7 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) +from synapse.storage.databases.main.room import LargestRoomStats from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.response_cache import ResponseCache @@ -170,26 +171,24 @@ class RoomListHandler: ignore_non_federatable=from_federation, ) - def build_room_entry(room: JsonDict) -> JsonDict: + def build_room_entry(room: LargestRoomStats) -> JsonDict: entry = { - "room_id": room["room_id"], - "name": room["name"], - "topic": room["topic"], - "canonical_alias": room["canonical_alias"], - "num_joined_members": room["joined_members"], - "avatar_url": room["avatar"], - "world_readable": room["history_visibility"] + "room_id": room.room_id, + "name": room.name, + "topic": room.topic, + "canonical_alias": room.canonical_alias, + "num_joined_members": room.joined_members, + "avatar_url": room.avatar, + "world_readable": room.history_visibility == HistoryVisibility.WORLD_READABLE, - "guest_can_join": room["guest_access"] == "can_join", - "join_rule": room["join_rules"], - "room_type": room["room_type"], + "guest_can_join": room.guest_access == "can_join", + "join_rule": room.join_rules, + "room_type": room.room_type, } # Filter out Nones – rather omit the field altogether return {k: v for k, v in entry.items() if v is not None} - results = [build_room_entry(r) for r in results] - response: JsonDict = {} num_results = len(results) if limit is not None: @@ -212,33 +211,33 @@ class RoomListHandler: # If there was a token given then we assume that there # must be previous results. response["prev_batch"] = RoomListNextBatch( - last_joined_members=initial_entry["num_joined_members"], - last_room_id=initial_entry["room_id"], + last_joined_members=initial_entry.joined_members, + last_room_id=initial_entry.room_id, direction_is_forward=False, ).to_token() if more_to_come: response["next_batch"] = RoomListNextBatch( - last_joined_members=final_entry["num_joined_members"], - last_room_id=final_entry["room_id"], + last_joined_members=final_entry.joined_members, + last_room_id=final_entry.room_id, direction_is_forward=True, ).to_token() else: if has_batch_token: response["next_batch"] = RoomListNextBatch( - last_joined_members=final_entry["num_joined_members"], - last_room_id=final_entry["room_id"], + last_joined_members=final_entry.joined_members, + last_room_id=final_entry.room_id, direction_is_forward=True, ).to_token() if more_to_come: response["prev_batch"] = RoomListNextBatch( - last_joined_members=initial_entry["num_joined_members"], - last_room_id=initial_entry["room_id"], + last_joined_members=initial_entry.joined_members, + last_room_id=initial_entry.room_id, direction_is_forward=False, ).to_token() - response["chunk"] = results + response["chunk"] = [build_room_entry(r) for r in results] response["total_room_count_estimate"] = await self.store.count_public_rooms( network_tuple, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
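The room-list change replaces dict-shaped rows with attrs objects (`LargestRoomStats`) while keeping the existing "drop None fields" behaviour. The fragment below sketches that pattern with a made-up `RoomStats` class; only the shape comes from the diff, and the field set is trimmed for brevity.

```python
from typing import Optional

import attr


@attr.s(auto_attribs=True, frozen=True, slots=True)
class RoomStats:
    room_id: str
    name: Optional[str]
    joined_members: int
    topic: Optional[str] = None


def build_room_entry(room: RoomStats) -> dict:
    entry = {
        "room_id": room.room_id,
        "name": room.name,
        "num_joined_members": room.joined_members,
        "topic": room.topic,
    }
    # Omit fields that are unset rather than returning nulls.
    return {k: v for k, v in entry.items() if v is not None}


print(build_room_entry(RoomStats("!a:example.org", None, 3)))
# {'room_id': '!a:example.org', 'num_joined_members': 3}
```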
index 918eb203e2..eddc2af9ba 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py
@@ -1260,7 +1260,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Add new room to the room directory if the old room was there # Remove old room from the room directory old_room = await self.store.get_room(old_room_id) - if old_room is not None and old_room["is_public"]: + # If the old room exists and is public. + if old_room is not None and old_room[0]: await self.store.set_room_is_public(old_room_id, False) await self.store.set_room_is_public(room_id, True) diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index dd559b4c45..1dfb12e065 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py
@@ -703,24 +703,24 @@ class RoomSummaryHandler: # there should always be an entry assert stats is not None, "unable to retrieve stats for %s" % (room_id,) - entry = { - "room_id": stats["room_id"], - "name": stats["name"], - "topic": stats["topic"], - "canonical_alias": stats["canonical_alias"], - "num_joined_members": stats["joined_members"], - "avatar_url": stats["avatar"], - "join_rule": stats["join_rules"], + entry: JsonDict = { + "room_id": stats.room_id, + "name": stats.name, + "topic": stats.topic, + "canonical_alias": stats.canonical_alias, + "num_joined_members": stats.joined_members, + "avatar_url": stats.avatar, + "join_rule": stats.join_rules, "world_readable": ( - stats["history_visibility"] == HistoryVisibility.WORLD_READABLE + stats.history_visibility == HistoryVisibility.WORLD_READABLE ), - "guest_can_join": stats["guest_access"] == "can_join", - "room_type": stats["room_type"], + "guest_can_join": stats.guest_access == "can_join", + "room_type": stats.room_type, } if self._msc3266_enabled: - entry["im.nheko.summary.version"] = stats["version"] - entry["im.nheko.summary.encryption"] = stats["encryption"] + entry["im.nheko.summary.version"] = stats.version + entry["im.nheko.summary.encryption"] = stats.encryption # Federation requests need to provide additional information so the # requested server is able to filter the response appropriately. diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 62f2454f5d..389dc5298a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py
@@ -806,7 +806,7 @@ class SsoHandler: media_id = profile["avatar_url"].split("/")[-1] if self._is_mine_server_name(server_name): media = await self._media_repo.store.get_local_media(media_id) - if media is not None and upload_name == media["upload_name"]: + if media is not None and upload_name == media.upload_name: logger.info("skipping saving the user avatar") return True diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 2f1bc5a015..bf0106c6e7 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -399,7 +399,7 @@ class SyncHandler: # # If that happens, we mustn't cache it, so that when the client comes back # with the same cache token, we don't immediately return the same empty - # result, causing a tightloop. (#8518) + # result, causing a tightloop. (https://github.com/matrix-org/synapse/issues/8518) if result.next_batch == since_token: cache_context.should_cache = False @@ -1003,7 +1003,7 @@ class SyncHandler: # always make sure we LL ourselves so we know we're in the room # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209 # We only need apply this on full state syncs given we disabled - # LL for incr syncs in #3840. + # LL for incr syncs in https://github.com/matrix-org/synapse/pull/3840. # We don't insert ourselves into `members_to_fetch`, because in some # rare cases (an empty event batch with a now_token after the user's # leave in a partial state room which another local user has diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 75717ba4f9..3c19ea56f8 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py
@@ -184,8 +184,8 @@ class UserDirectoryHandler(StateDeltasHandler): """Called to update index of our local user profiles when they change irrespective of any rooms the user may be in. """ - # FIXME(#3714): We should probably do this in the same worker as all - # the other changes. + # FIXME(https://github.com/matrix-org/synapse/issues/3714): We should + # probably do this in the same worker as all the other changes. if await self.store.should_include_local_user_in_dir(user_id): await self.store.update_profile_in_user_dir( @@ -194,8 +194,8 @@ class UserDirectoryHandler(StateDeltasHandler): async def handle_local_user_deactivated(self, user_id: str) -> None: """Called when a user ID is deactivated""" - # FIXME(#3714): We should probably do this in the same worker as all - # the other changes. + # FIXME(https://github.com/matrix-org/synapse/issues/3714): We should + # probably do this in the same worker as all the other changes. await self.store.remove_from_user_dir(user_id) async def _unsafe_process(self) -> None: diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 08c7fc1631..d5013e8e97 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py
@@ -465,7 +465,7 @@ class MatrixFederationHttpClient: """Wrapper for _send_request which can optionally retry the request upon receiving a combination of a 400 HTTP response code and a 'M_UNRECOGNIZED' errcode. This is a workaround for Synapse <= v0.99.3 - due to #3622. + due to https://github.com/matrix-org/synapse/issues/3622. Args: request: details of request to be sent @@ -958,9 +958,9 @@ class MatrixFederationHttpClient: requests). try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED response we should try appending a trailing slash to the end - of the request. Workaround for #3622 in Synapse <= v0.99.3. This - will be attempted before backing off if backing off has been - enabled. + of the request. Workaround for https://github.com/matrix-org/synapse/issues/3622 + in Synapse <= v0.99.3. This will be attempted before backing off if + backing off has been enabled. parser: The parser to use to decode the response. Defaults to parsing as JSON. backoff_on_all_error_codes: Back off if we get any error response @@ -1155,7 +1155,8 @@ class MatrixFederationHttpClient: try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED response we should try appending a trailing slash to the end of - the request. Workaround for #3622 in Synapse <= v0.99.3. + the request. Workaround for https://github.com/matrix-org/synapse/issues/3622 + in Synapse <= v0.99.3. parser: The parser to use to decode the response. Defaults to parsing as JSON. @@ -1250,7 +1251,8 @@ class MatrixFederationHttpClient: try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED response we should try appending a trailing slash to the end of - the request. Workaround for #3622 in Synapse <= v0.99.3. + the request. Workaround for https://github.com/matrix-org/synapse/issues/3622 + in Synapse <= v0.99.3. parser: The parser to use to decode the response. Defaults to parsing as JSON. diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 4454fe29a5..e297fa9c8b 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py
@@ -1019,11 +1019,14 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: if not opentracing: return func + # getfullargspec is somewhat expensive, so ensure it is only called a single + # time (the function signature shouldn't change anyway). + argspec = inspect.getfullargspec(func) + @contextlib.contextmanager def _wrapping_logic( - func: Callable[P, R], *args: P.args, **kwargs: P.kwargs + _func: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> Generator[None, None, None]: - argspec = inspect.getfullargspec(func) # We use `[1:]` to skip the `self` object reference and `start=1` to # make the index line up with `argspec.args`. # diff --git a/synapse/media/_base.py b/synapse/media/_base.py
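The `tag_args` fix hoists `inspect.getfullargspec` out of the per-call wrapper, since a function's signature is fixed at decoration time. A small standalone decorator showing the same idea, using a hypothetical `log_args` decorator rather than Synapse's tracing machinery:

```python
import functools
import inspect
import logging
from typing import Any, Callable, TypeVar

logger = logging.getLogger(__name__)
F = TypeVar("F", bound=Callable[..., Any])


def log_args(func: F) -> F:
    # Compute the (relatively expensive) argspec once, at decoration time;
    # the signature cannot change between calls.
    argspec = inspect.getfullargspec(func)

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        for name, value in zip(argspec.args, args):
            logger.debug("%s: %s=%r", func.__name__, name, value)
        return func(*args, **kwargs)

    return wrapper  # type: ignore[return-value]


@log_args
def add(a: int, b: int) -> int:
    return a + b


print(add(2, 3))  # 5 (argument values are logged at DEBUG level)
```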
index 860e5ddca2..9d88a711cf 100644 --- a/synapse/media/_base.py +++ b/synapse/media/_base.py
@@ -83,6 +83,12 @@ INLINE_CONTENT_TYPES = [ "audio/x-flac", ] +# Default timeout_ms for download and thumbnail requests +DEFAULT_MAX_TIMEOUT_MS = 20_000 + +# Maximum allowed timeout_ms for download and thumbnail requests +MAXIMUM_ALLOWED_MAX_TIMEOUT_MS = 60_000 + def respond_404(request: SynapseRequest) -> None: assert request.path is not None diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
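These constants bound the new `timeout_ms` query parameter on download and thumbnail requests. The sketch below shows how a handler might apply them; the fallback-on-malformed-input behaviour is a simplification of my own, whereas Synapse's real request parsers reject invalid values with a 400 rather than silently using the default.

```python
from typing import Optional

# Constants mirroring the diff above.
DEFAULT_MAX_TIMEOUT_MS = 20_000
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS = 60_000


def parse_timeout_ms(raw: Optional[str]) -> int:
    """Parse a client-supplied ?timeout_ms= value, falling back to the default
    and clamping to the allowed maximum."""
    if raw is None:
        return DEFAULT_MAX_TIMEOUT_MS
    try:
        timeout_ms = int(raw)
    except ValueError:
        return DEFAULT_MAX_TIMEOUT_MS
    # Negative values make no sense; overly large values would tie up workers.
    return max(0, min(timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS))


assert parse_timeout_ms(None) == 20_000
assert parse_timeout_ms("5000") == 5_000
assert parse_timeout_ms("999999") == 60_000
```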
index 72b0f1c5de..bf976b9e7c 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py
@@ -19,6 +19,7 @@ import shutil from io import BytesIO from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple +import attr from matrix_common.types.mxc_uri import MXCUri import twisted.internet.error @@ -26,13 +27,16 @@ import twisted.web.http from twisted.internet.defer import Deferred from synapse.api.errors import ( + Codes, FederationDeniedError, HttpResponseException, NotFoundError, RequestSendFailed, SynapseError, + cs_error, ) from synapse.config.repository import ThumbnailRequirement +from synapse.http.server import respond_with_json from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.logging.opentracing import trace @@ -50,6 +54,7 @@ from synapse.media.storage_provider import StorageProviderWrapper from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.url_previewer import UrlPreviewer from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia from synapse.types import UserID from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination @@ -78,6 +83,8 @@ class MediaRepository: self.store = hs.get_datastores().main self.max_upload_size = hs.config.media.max_upload_size self.max_image_pixels = hs.config.media.max_image_pixels + self.unused_expiration_time = hs.config.media.unused_expiration_time + self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads Thumbnailer.set_limits(self.max_image_pixels) @@ -184,6 +191,117 @@ class MediaRepository: self.recently_accessed_locals.add(media_id) @trace + async def create_media_id(self, auth_user: UserID) -> Tuple[str, int]: + """Create and store a media ID for a local user and return the MXC URI and its + expiration. + + Args: + auth_user: The user_id of the uploader + + Returns: + A tuple containing the MXC URI of the stored content and the timestamp at + which the MXC URI expires. + """ + media_id = random_string(24) + now = self.clock.time_msec() + await self.store.store_local_media_id( + media_id=media_id, + time_now_ms=now, + user_id=auth_user, + ) + return f"mxc://{self.server_name}/{media_id}", now + self.unused_expiration_time + + @trace + async def reached_pending_media_limit(self, auth_user: UserID) -> Tuple[bool, int]: + """Check if the user is over the limit for pending media uploads. + + Args: + auth_user: The user_id of the uploader + + Returns: + A tuple with a boolean and an integer indicating whether the user has too + many pending media uploads and the timestamp at which the first pending + media will expire, respectively. + """ + pending, first_expiration_ts = await self.store.count_pending_media( + user_id=auth_user + ) + return pending >= self.max_pending_media_uploads, first_expiration_ts + + @trace + async def verify_can_upload(self, media_id: str, auth_user: UserID) -> None: + """Verify that the media ID can be uploaded to by the given user. 
This + function checks that: + + * the media ID exists + * the media ID does not already have content + * the user uploading is the same as the one who created the media ID + * the media ID has not expired + + Args: + media_id: The media ID to verify + auth_user: The user_id of the uploader + """ + media = await self.store.get_local_media(media_id) + if media is None: + raise SynapseError(404, "Unknow media ID", errcode=Codes.NOT_FOUND) + + if media.user_id != auth_user.to_string(): + raise SynapseError( + 403, + "Only the creator of the media ID can upload to it", + errcode=Codes.FORBIDDEN, + ) + + if media.media_length is not None: + raise SynapseError( + 409, + "Media ID already has content", + errcode=Codes.CANNOT_OVERWRITE_MEDIA, + ) + + expired_time_ms = self.clock.time_msec() - self.unused_expiration_time + if media.created_ts < expired_time_ms: + raise NotFoundError("Media ID has expired") + + @trace + async def update_content( + self, + media_id: str, + media_type: str, + upload_name: Optional[str], + content: IO, + content_length: int, + auth_user: UserID, + ) -> None: + """Update the content of the given media ID. + + Args: + media_id: The media ID to replace. + media_type: The content type of the file. + upload_name: The name of the file, if provided. + content: A file like object that is the content to store + content_length: The length of the content + auth_user: The user_id of the uploader + """ + file_info = FileInfo(server_name=None, file_id=media_id) + fname = await self.media_storage.store_file(content, file_info) + logger.info("Stored local media in file %r", fname) + + await self.store.update_local_media( + media_id=media_id, + media_type=media_type, + upload_name=upload_name, + media_length=content_length, + user_id=auth_user, + ) + + try: + await self._generate_thumbnails(None, media_id, media_id, media_type) + except Exception as e: + logger.info("Failed to generate thumbnails: %s", e) + + @trace async def create_content( self, media_type: str, @@ -229,8 +347,74 @@ class MediaRepository: return MXCUri(self.server_name, media_id) + def respond_not_yet_uploaded(self, request: SynapseRequest) -> None: + respond_with_json( + request, + 504, + cs_error("Media has not been uploaded yet", code=Codes.NOT_YET_UPLOADED), + send_cors=True, + ) + + async def get_local_media_info( + self, request: SynapseRequest, media_id: str, max_timeout_ms: int + ) -> Optional[LocalMedia]: + """Gets the info dictionary for given local media ID. If the media has + not been uploaded yet, this function will wait up to ``max_timeout_ms`` + milliseconds for the media to be uploaded. + + Args: + request: The incoming request. + media_id: The media ID of the content. (This is the same as + the file_id for local content.) + max_timeout_ms: the maximum number of milliseconds to wait for the + media to be uploaded. + + Returns: + Either the info dictionary for the given local media ID or + ``None``. If ``None``, then no further processing is necessary as + this function will send the necessary JSON response. 
+ """ + wait_until = self.clock.time_msec() + max_timeout_ms + while True: + # Get the info for the media + media_info = await self.store.get_local_media(media_id) + if not media_info: + logger.info("Media %s is unknown", media_id) + respond_404(request) + return None + + if media_info.quarantined_by: + logger.info("Media %s is quarantined", media_id) + respond_404(request) + return None + + # The file has been uploaded, so stop looping + if media_info.media_length is not None: + return media_info + + # Check if the media ID has expired and still hasn't been uploaded to. + now = self.clock.time_msec() + expired_time_ms = now - self.unused_expiration_time + if media_info.created_ts < expired_time_ms: + logger.info("Media %s has expired without being uploaded", media_id) + respond_404(request) + return None + + if now >= wait_until: + break + + await self.clock.sleep(0.5) + + logger.info("Media %s has not yet been uploaded", media_id) + self.respond_not_yet_uploaded(request) + return None + async def get_local_media( - self, request: SynapseRequest, media_id: str, name: Optional[str] + self, + request: SynapseRequest, + media_id: str, + name: Optional[str], + max_timeout_ms: int, ) -> None: """Responds to requests for local media, if exists, or returns 404. @@ -240,23 +424,24 @@ class MediaRepository: the file_id for local content.) name: Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. + max_timeout_ms: the maximum number of milliseconds to wait for the + media to be uploaded. Returns: Resolves once a response has successfully been written to request """ - media_info = await self.store.get_local_media(media_id) - if not media_info or media_info["quarantined_by"]: - respond_404(request) + media_info = await self.get_local_media_info(request, media_id, max_timeout_ms) + if not media_info: return self.mark_recently_accessed(None, media_id) - media_type = media_info["media_type"] + media_type = media_info.media_type if not media_type: media_type = "application/octet-stream" - media_length = media_info["media_length"] - upload_name = name if name else media_info["upload_name"] - url_cache = media_info["url_cache"] + media_length = media_info.media_length + upload_name = name if name else media_info.upload_name + url_cache = media_info.url_cache file_info = FileInfo(None, media_id, url_cache=bool(url_cache)) @@ -271,6 +456,7 @@ class MediaRepository: server_name: str, media_id: str, name: Optional[str], + max_timeout_ms: int, ) -> None: """Respond to requests for remote media. @@ -280,6 +466,8 @@ class MediaRepository: media_id: The media ID of the content (as defined by the remote server). name: Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. + max_timeout_ms: the maximum number of milliseconds to wait for the + media to be uploaded. 
Returns: Resolves once a response has successfully been written to request @@ -305,27 +493,33 @@ class MediaRepository: key = (server_name, media_id) async with self.remote_media_linearizer.queue(key): responder, media_info = await self._get_remote_media_impl( - server_name, media_id + server_name, media_id, max_timeout_ms ) # We deliberately stream the file outside the lock - if responder: - media_type = media_info["media_type"] - media_length = media_info["media_length"] - upload_name = name if name else media_info["upload_name"] + if responder and media_info: + upload_name = name if name else media_info.upload_name await respond_with_responder( - request, responder, media_type, media_length, upload_name + request, + responder, + media_info.media_type, + media_info.media_length, + upload_name, ) else: respond_404(request) - async def get_remote_media_info(self, server_name: str, media_id: str) -> dict: + async def get_remote_media_info( + self, server_name: str, media_id: str, max_timeout_ms: int + ) -> RemoteMedia: """Gets the media info associated with the remote file, downloading if necessary. Args: server_name: Remote server_name where the media originated. media_id: The media ID of the content (as defined by the remote server). + max_timeout_ms: the maximum number of milliseconds to wait for the + media to be uploaded. Returns: The media info of the file @@ -341,7 +535,7 @@ class MediaRepository: key = (server_name, media_id) async with self.remote_media_linearizer.queue(key): responder, media_info = await self._get_remote_media_impl( - server_name, media_id + server_name, media_id, max_timeout_ms ) # Ensure we actually use the responder so that it releases resources @@ -352,8 +546,8 @@ class MediaRepository: return media_info async def _get_remote_media_impl( - self, server_name: str, media_id: str - ) -> Tuple[Optional[Responder], dict]: + self, server_name: str, media_id: str, max_timeout_ms: int + ) -> Tuple[Optional[Responder], RemoteMedia]: """Looks for media in local cache, if not there then attempt to download from remote server. @@ -361,6 +555,8 @@ class MediaRepository: server_name: Remote server_name where the media originated. media_id: The media ID of the content (as defined by the remote server). + max_timeout_ms: the maximum number of milliseconds to wait for the + media to be uploaded. Returns: A tuple of responder and the media info of the file. 
@@ -373,15 +569,17 @@ class MediaRepository: # If we have an entry in the DB, try and look for it if media_info: - file_id = media_info["filesystem_id"] + file_id = media_info.filesystem_id file_info = FileInfo(server_name, file_id) - if media_info["quarantined_by"]: + if media_info.quarantined_by: logger.info("Media is quarantined") raise NotFoundError() - if not media_info["media_type"]: - media_info["media_type"] = "application/octet-stream" + if not media_info.media_type: + media_info = attr.evolve( + media_info, media_type="application/octet-stream" + ) responder = await self.media_storage.fetch_media(file_info) if responder: @@ -391,8 +589,7 @@ class MediaRepository: try: media_info = await self._download_remote_file( - server_name, - media_id, + server_name, media_id, max_timeout_ms ) except SynapseError: raise @@ -403,9 +600,9 @@ class MediaRepository: if not media_info: raise e - file_id = media_info["filesystem_id"] - if not media_info["media_type"]: - media_info["media_type"] = "application/octet-stream" + file_id = media_info.filesystem_id + if not media_info.media_type: + media_info = attr.evolve(media_info, media_type="application/octet-stream") file_info = FileInfo(server_name, file_id) # We generate thumbnails even if another process downloaded the media @@ -415,7 +612,7 @@ class MediaRepository: # otherwise they'll request thumbnails and get a 404 if they're not # ready yet. await self._generate_thumbnails( - server_name, media_id, file_id, media_info["media_type"] + server_name, media_id, file_id, media_info.media_type ) responder = await self.media_storage.fetch_media(file_info) @@ -425,7 +622,8 @@ class MediaRepository: self, server_name: str, media_id: str, - ) -> dict: + max_timeout_ms: int, + ) -> RemoteMedia: """Attempt to download the remote file from the given server name, using the given file_id as the local id. @@ -434,7 +632,8 @@ class MediaRepository: media_id: The media ID of the content (as defined by the remote server). This is different than the file_id, which is locally generated. - file_id: Local file ID + max_timeout_ms: the maximum number of milliseconds to wait for the + media to be uploaded. Returns: The media info of the file. @@ -458,7 +657,8 @@ class MediaRepository: # tell the remote server to 404 if it doesn't # recognise the server_name, to make sure we don't # end up with a routing loop. - "allow_remote": "false" + "allow_remote": "false", + "timeout_ms": str(max_timeout_ms), }, ) except RequestSendFailed as e: @@ -518,7 +718,7 @@ class MediaRepository: origin=server_name, media_id=media_id, media_type=media_type, - time_now_ms=self.clock.time_msec(), + time_now_ms=time_now_ms, upload_name=upload_name, media_length=length, filesystem_id=file_id, @@ -526,15 +726,17 @@ class MediaRepository: logger.info("Stored remote media in file %r", fname) - media_info = { - "media_type": media_type, - "media_length": length, - "upload_name": upload_name, - "created_ts": time_now_ms, - "filesystem_id": file_id, - } - - return media_info + return RemoteMedia( + media_origin=server_name, + media_id=media_id, + media_type=media_type, + media_length=length, + upload_name=upload_name, + created_ts=time_now_ms, + filesystem_id=file_id, + last_access_ts=time_now_ms, + quarantined_by=None, + ) def _get_thumbnail_requirements( self, media_type: str diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py
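Much of the media-repository change implements asynchronous uploads: a media ID is created first, and download requests may now wait up to `max_timeout_ms` for the content to arrive before answering 504 `M_NOT_YET_UPLOADED`. The generic "poll until a value appears or the deadline passes" loop that `get_local_media_info` uses can be sketched as below; asyncio is used purely for illustration, since Synapse itself runs on Twisted.

```python
import asyncio
import time
from typing import Awaitable, Callable, Optional, TypeVar

T = TypeVar("T")


async def poll_until(
    fetch: Callable[[], Awaitable[Optional[T]]],
    max_timeout_ms: int,
    interval_s: float = 0.5,
) -> Optional[T]:
    """Call `fetch` until it returns a value or the deadline passes.

    Mirrors the shape of get_local_media_info above: the caller decides what
    to do when None comes back (there, respond 504 M_NOT_YET_UPLOADED).
    """
    deadline = time.monotonic() + max_timeout_ms / 1000
    while True:
        result = await fetch()
        if result is not None:
            return result
        if time.monotonic() >= deadline:
            return None
        await asyncio.sleep(interval_s)


async def main() -> None:
    uploaded_at = time.monotonic() + 1.2  # pretend the upload lands in ~1.2s

    async def fetch_media() -> Optional[str]:
        return "media-bytes" if time.monotonic() >= uploaded_at else None

    print(await poll_until(fetch_media, max_timeout_ms=5_000))  # "media-bytes"


asyncio.run(main())
```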
index 9b5a3dd5f4..44aac21de6 100644 --- a/synapse/media/url_previewer.py +++ b/synapse/media/url_previewer.py
@@ -240,15 +240,14 @@ class UrlPreviewer: cache_result = await self.store.get_url_cache(url, ts) if ( cache_result - and cache_result["expires_ts"] > ts - and cache_result["response_code"] / 100 == 2 + and cache_result.expires_ts > ts + and cache_result.response_code // 100 == 2 ): # It may be stored as text in the database, not as bytes (such as # PostgreSQL). If so, encode it back before handing it on. - og = cache_result["og"] - if isinstance(og, str): - og = og.encode("utf8") - return og + if isinstance(cache_result.og, str): + return cache_result.og.encode("utf8") + return cache_result.og # If this URL can be accessed via an allowed oEmbed, use that instead. url_to_download = url diff --git a/synapse/metrics/_reactor_metrics.py b/synapse/metrics/_reactor_metrics.py
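The switch from `/` to `//` in the URL-preview cache check matters because Python 3's `/` is true division: a 204 response gives `2.04`, which never equals `2`, so valid 2xx cache entries were treated as misses. Floor division restores the intended status-class check:

```python
response_code = 204

print(response_code / 100 == 2)   # False: true division gives 2.04
print(response_code // 100 == 2)  # True: floor division classifies any 2xx
```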
index a2c6e6842d..dd486dd3e2 100644 --- a/synapse/metrics/_reactor_metrics.py +++ b/synapse/metrics/_reactor_metrics.py
@@ -12,17 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. -import select +import logging import time -from typing import Any, Iterable, List, Tuple +from selectors import SelectSelector, _PollLikeSelector # type: ignore[attr-defined] +from typing import Any, Callable, Iterable from prometheus_client import Histogram, Metric from prometheus_client.core import REGISTRY, GaugeMetricFamily -from twisted.internet import reactor +from twisted.internet import reactor, selectreactor +from twisted.internet.asyncioreactor import AsyncioSelectorReactor from synapse.metrics._types import Collector +try: + from selectors import KqueueSelector +except ImportError: + + class KqueueSelector: # type: ignore[no-redef] + pass + + +try: + from twisted.internet.epollreactor import EPollReactor +except ImportError: + + class EPollReactor: # type: ignore[no-redef] + pass + + +try: + from twisted.internet.pollreactor import PollReactor +except ImportError: + + class PollReactor: # type: ignore[no-redef] + pass + + +logger = logging.getLogger(__name__) + # # Twisted reactor metrics # @@ -34,52 +62,100 @@ tick_time = Histogram( ) -class EpollWrapper: - """a wrapper for an epoll object which records the time between polls""" +class CallWrapper: + """A wrapper for a callable which records the time between calls""" - def __init__(self, poller: "select.epoll"): # type: ignore[name-defined] + def __init__(self, wrapped: Callable[..., Any]): self.last_polled = time.time() - self._poller = poller + self._wrapped = wrapped - def poll(self, *args, **kwargs) -> List[Tuple[int, int]]: # type: ignore[no-untyped-def] - # record the time since poll() was last called. This gives a good proxy for + def __call__(self, *args, **kwargs) -> Any: # type: ignore[no-untyped-def] + # record the time since this was last called. This gives a good proxy for # how long it takes to run everything in the reactor - ie, how long anything # waiting for the next tick will have to wait. tick_time.observe(time.time() - self.last_polled) - ret = self._poller.poll(*args, **kwargs) + ret = self._wrapped(*args, **kwargs) self.last_polled = time.time() return ret + +class ObjWrapper: + """A wrapper for an object which wraps a specified method in CallWrapper. + + Other methods/attributes are passed to the original object. + + This is necessary when the wrapped object does not allow the attribute to be + overwritten. + """ + + def __init__(self, wrapped: Any, method_name: str): + self._wrapped = wrapped + self._method_name = method_name + self._wrapped_method = CallWrapper(getattr(wrapped, method_name)) + def __getattr__(self, item: str) -> Any: - return getattr(self._poller, item) + if item == self._method_name: + return self._wrapped_method + + return getattr(self._wrapped, item) class ReactorLastSeenMetric(Collector): - def __init__(self, epoll_wrapper: EpollWrapper): - self._epoll_wrapper = epoll_wrapper + def __init__(self, call_wrapper: CallWrapper): + self._call_wrapper = call_wrapper def collect(self) -> Iterable[Metric]: cm = GaugeMetricFamily( "python_twisted_reactor_last_seen", "Seconds since the Twisted reactor was last seen", ) - cm.add_metric([], time.time() - self._epoll_wrapper.last_polled) + cm.add_metric([], time.time() - self._call_wrapper.last_polled) yield cm +# Twisted has already select a reasonable reactor for us, so assumptions can be +# made about the shape. 
+wrapper = None try: - # if the reactor has a `_poller` attribute, which is an `epoll` object - # (ie, it's an EPollReactor), we wrap the `epoll` with a thing that will - # measure the time between ticks - from select import epoll # type: ignore[attr-defined] - - poller = reactor._poller # type: ignore[attr-defined] -except (AttributeError, ImportError): - pass -else: - if isinstance(poller, epoll): - poller = EpollWrapper(poller) - reactor._poller = poller # type: ignore[attr-defined] - REGISTRY.register(ReactorLastSeenMetric(poller)) + if isinstance(reactor, (PollReactor, EPollReactor)): + reactor._poller = ObjWrapper(reactor._poller, "poll") # type: ignore[attr-defined] + wrapper = reactor._poller._wrapped_method # type: ignore[attr-defined] + + elif isinstance(reactor, selectreactor.SelectReactor): + # Twisted uses a module-level _select function. + wrapper = selectreactor._select = CallWrapper(selectreactor._select) + + elif isinstance(reactor, AsyncioSelectorReactor): + # For asyncio look at the underlying asyncio event loop. + asyncio_loop = reactor._asyncioEventloop # A sub-class of BaseEventLoop, + + # A sub-class of BaseSelector. + selector = asyncio_loop._selector # type: ignore[attr-defined] + + if isinstance(selector, SelectSelector): + wrapper = selector._select = CallWrapper(selector._select) # type: ignore[attr-defined] + + # poll, epoll, and /dev/poll. + elif isinstance(selector, _PollLikeSelector): + selector._selector = ObjWrapper(selector._selector, "poll") # type: ignore[attr-defined] + wrapper = selector._selector._wrapped_method # type: ignore[attr-defined] + + elif isinstance(selector, KqueueSelector): + selector._selector = ObjWrapper(selector._selector, "control") # type: ignore[attr-defined] + wrapper = selector._selector._wrapped_method # type: ignore[attr-defined] + + else: + # E.g. this does not support the (Windows-only) ProactorEventLoop. + logger.warning( + "Skipping configuring ReactorLastSeenMetric: unexpected asyncio loop selector: %r via %r", + selector, + asyncio_loop, + ) +except Exception as e: + logger.warning("Configuring ReactorLastSeenMetric failed: %r", e) + + +if wrapper: + REGISTRY.register(ReactorLastSeenMetric(wrapper)) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
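The reactor-metrics rewrite generalises the old epoll-only wrapper: any poll-like callable can be wrapped so the gap between calls (a proxy for reactor busyness) is recorded. A self-contained sketch of that wrapper follows, with a plain list standing in for the Prometheus histogram and a dummy `fake_poll` instead of a real selector; the diff's version additionally needs `ObjWrapper` because real poller objects do not allow their methods to be reassigned.

```python
import time
from typing import Any, Callable, List


class CallWrapper:
    """Record the time between successive calls to the wrapped callable.

    Long gaps between selector polls are a good proxy for a busy (or stalled)
    event loop, which is what the reactor metric above measures.
    """

    def __init__(self, wrapped: Callable[..., Any]) -> None:
        self._wrapped = wrapped
        self.last_called = time.monotonic()
        self.gaps: List[float] = []  # stand-in for the Prometheus histogram

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        self.gaps.append(time.monotonic() - self.last_called)
        ret = self._wrapped(*args, **kwargs)
        self.last_called = time.monotonic()
        return ret


def fake_poll(timeout: float) -> list:
    """Stand-in for a selector's poll(); returns no ready file descriptors."""
    return []


wrapped_poll = CallWrapper(fake_poll)

# Simulate two loop iterations with ~50ms of "work" between polls.
for _ in range(2):
    wrapped_poll(0)
    time.sleep(0.05)  # pretend the reactor ran callbacks

print([round(g, 2) for g in wrapped_poll.gaps])  # e.g. [0.0, 0.05]
```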
index 755c59274c..812144a128 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py
@@ -1860,7 +1860,8 @@ class PublicRoomListManager: if not room: return False - return room.get("is_public", False) + # The first item is whether the room is public. + return room[0] async def add_room_to_public_room_list(self, room_id: str) -> None: """Publishes a room to the public room list. diff --git a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py
index ecaeef3511..7419785aff 100644 --- a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py +++ b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py
@@ -295,7 +295,8 @@ class ThirdPartyEventRulesModuleApiCallbacks: raise except SynapseError as e: # FIXME: Being able to throw SynapseErrors is relied upon by - # some modules. PR #10386 accidentally broke this ability. + # some modules. PR https://github.com/matrix-org/synapse/pull/10386 + # accidentally broke this ability. # That said, we aren't keen on exposing this implementation detail # to modules and we should one day have a proper way to do what # is wanted. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 14784312dc..5934b1ef34 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -25,10 +25,13 @@ from typing import ( Sequence, Tuple, Union, + cast, ) from prometheus_client import Counter +from twisted.internet.defer import Deferred + from synapse.api.constants import ( MAIN_TIMELINE, EventContentFields, @@ -40,11 +43,15 @@ from synapse.api.room_versions import PushRuleRoomFlag from synapse.event_auth import auth_types_for_event, get_user_power_level from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership +from synapse.storage.roommember import ProfileInfo from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator from synapse.types import JsonValue from synapse.types.state import StateFilter +from synapse.util import unwrapFirstError +from synapse.util.async_helpers import gather_results from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state @@ -342,15 +349,41 @@ class BulkPushRuleEvaluator: rules_by_user = await self._get_rules_for_event(event) actions_by_user: Dict[str, Collection[Union[Mapping, str]]] = {} - room_member_count = await self.store.get_number_joined_users_in_room( - event.room_id - ) - + # Gather a bunch of info in parallel. + # + # This has a lot of ignored types and casting due to the use of @cached + # decorated functions passed into run_in_background. + # + # See https://github.com/matrix-org/synapse/issues/16606 ( - power_levels, - sender_power_level, - ) = await self._get_power_levels_and_sender_level( - event, context, event_id_to_event + room_member_count, + (power_levels, sender_power_level), + related_events, + profiles, + ) = await make_deferred_yieldable( + cast( + "Deferred[Tuple[int, Tuple[dict, Optional[int]], Dict[str, Dict[str, JsonValue]], Mapping[str, ProfileInfo]]]", + gather_results( + ( + run_in_background( # type: ignore[call-arg] + self.store.get_number_joined_users_in_room, event.room_id # type: ignore[arg-type] + ), + run_in_background( + self._get_power_levels_and_sender_level, + event, + context, + event_id_to_event, + ), + run_in_background(self._related_events, event), + run_in_background( # type: ignore[call-arg] + self.store.get_subset_users_in_room_with_profiles, + event.room_id, # type: ignore[arg-type] + rules_by_user.keys(), # type: ignore[arg-type] + ), + ), + consumeErrors=True, + ).addErrback(unwrapFirstError), + ) ) # Find the event's thread ID. @@ -366,8 +399,6 @@ class BulkPushRuleEvaluator: # the parent is part of a thread. thread_id = await self.store.get_thread_id(relation.parent_id) - related_events = await self._related_events(event) - # It's possible that old room versions have non-integer power levels (floats or # strings; even the occasional `null`). For old rooms, we interpret these as if # they were integers. Do this here for the `@room` power level threshold. @@ -400,11 +431,6 @@ class BulkPushRuleEvaluator: self.hs.config.experimental.msc1767_enabled, # MSC3931 flag ) - users = rules_by_user.keys() - profiles = await self.store.get_subset_users_in_room_with_profiles( - event.room_id, users - ) - for uid, rules in rules_by_user.items(): if event.sender == uid: continue diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
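The push-rule change batches four independent lookups so they run concurrently instead of sequentially. Synapse does this with Twisted's `run_in_background`/`gather_results`; the asyncio sketch below shows the same idea with invented lookup functions and timings, purely to illustrate why the gathered version finishes in roughly one round-trip instead of three or four.

```python
import asyncio
import time
from typing import Dict, Tuple


async def get_member_count(room_id: str) -> int:
    await asyncio.sleep(0.1)  # pretend this is a DB query
    return 42


async def get_power_levels(room_id: str) -> Tuple[Dict[str, int], int]:
    await asyncio.sleep(0.1)
    return {"@alice:example.org": 100}, 100


async def get_related_events(event_id: str) -> Dict[str, dict]:
    await asyncio.sleep(0.1)
    return {}


async def evaluate(room_id: str, event_id: str) -> None:
    start = time.monotonic()
    # Issue the independent lookups concurrently rather than awaiting them
    # one after another; Synapse uses Twisted's run_in_background and
    # gather_results where this sketch uses asyncio.gather.
    member_count, (power_levels, sender_level), related = await asyncio.gather(
        get_member_count(room_id),
        get_power_levels(room_id),
        get_related_events(event_id),
    )
    print(member_count, sender_level, related, f"{time.monotonic() - start:.2f}s")


asyncio.run(evaluate("!room:example.org", "$event"))  # ~0.10s, not ~0.30s
```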
index afd03137f0..c14a18ba2e 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py
@@ -257,6 +257,11 @@ class ReplicationCommandHandler: if hs.config.redis.redis_enabled: self._notifier.add_lock_released_callback(self.on_lock_released) + # Marks if we should send POSITION commands for all streams ASAP. This + # is checked by the `ReplicationStreamer` which manages sending + # RDATA/POSITION commands + self._should_announce_positions = True + def subscribe_to_channel(self, channel_name: str) -> None: """ Indicates that we wish to subscribe to a Redis channel by name. @@ -397,29 +402,23 @@ class ReplicationCommandHandler: return self._streams_to_replicate def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand) -> None: - self.send_positions_to_connection(conn) + self.send_positions_to_connection() - def send_positions_to_connection(self, conn: IReplicationConnection) -> None: + def send_positions_to_connection(self) -> None: """Send current position of all streams this process is source of to the connection. """ - # We respond with current position of all streams this instance - # replicates. - for stream in self.get_streams_to_replicate(): - # Note that we use the current token as the prev token here (rather - # than stream.last_token), as we can't be sure that there have been - # no rows written between last token and the current token (since we - # might be racing with the replication sending bg process). - current_token = stream.current_token(self._instance_name) - self.send_command( - PositionCommand( - stream.NAME, - self._instance_name, - current_token, - current_token, - ) - ) + self._should_announce_positions = True + self._notifier.notify_replication() + + def should_announce_positions(self) -> bool: + """Check if we should send POSITION commands for all streams ASAP.""" + return self._should_announce_positions + + def will_announce_positions(self) -> None: + """Mark that we're about to send POSITIONs out for all streams.""" + self._should_announce_positions = False def on_USER_SYNC( self, conn: IReplicationConnection, cmd: UserSyncCommand @@ -588,6 +587,21 @@ class ReplicationCommandHandler: logger.debug("Handling '%s %s'", cmd.NAME, cmd.to_line()) + # Check if we can early discard this position. We can only do so for + # connected streams. + stream = self._streams[cmd.stream_name] + if stream.can_discard_position( + cmd.instance_name, cmd.prev_token, cmd.new_token + ) and self.is_stream_connected(conn, cmd.stream_name): + logger.debug( + "Discarding redundant POSITION %s/%s %s %s", + cmd.instance_name, + cmd.stream_name, + cmd.prev_token, + cmd.new_token, + ) + return + self._add_command_to_stream_queue(conn, cmd) async def _process_position( @@ -599,6 +613,18 @@ class ReplicationCommandHandler: """ stream = self._streams[stream_name] + if stream.can_discard_position( + cmd.instance_name, cmd.prev_token, cmd.new_token + ) and self.is_stream_connected(conn, cmd.stream_name): + logger.debug( + "Discarding redundant POSITION %s/%s %s %s", + cmd.instance_name, + cmd.stream_name, + cmd.prev_token, + cmd.new_token, + ) + return + # We're about to go and catch up with the stream, so remove from set # of connected streams. for streams in self._streams_by_connection.values(): @@ -626,8 +652,9 @@ class ReplicationCommandHandler: # for why this can happen. 
logger.info( - "Fetching replication rows for '%s' between %i and %i", + "Fetching replication rows for '%s' / %s between %i and %i", stream_name, + cmd.instance_name, current_token, cmd.new_token, ) @@ -657,6 +684,13 @@ class ReplicationCommandHandler: self._streams_by_connection.setdefault(conn, set()).add(stream_name) + def is_stream_connected( + self, conn: IReplicationConnection, stream_name: str + ) -> bool: + """Return if stream has been successfully connected and is ready to + receive updates""" + return stream_name in self._streams_by_connection.get(conn, ()) + def on_REMOTE_SERVER_UP( self, conn: IReplicationConnection, cmd: RemoteServerUpCommand ) -> None: diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 7e96145b3b..1fa37bb888 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -141,7 +141,7 @@ class RedisSubscriber(SubscriberProtocol): # We send out our positions when there is a new connection in case the # other side missed updates. We do this for Redis connections as the # otherside won't know we've connected and so won't issue a REPLICATE. - self.synapse_handler.send_positions_to_connection(self) + self.synapse_handler.send_positions_to_connection() def messageReceived(self, pattern: str, channel: str, message: str) -> None: """Received a message from redis.""" diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 38abb5df54..d15828f2d3 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -123,7 +123,7 @@ class ReplicationStreamer: # We check up front to see if anything has actually changed, as we get # poked because of changes that happened on other instances. - if all( + if not self.command_handler.should_announce_positions() and all( stream.last_token == stream.current_token(self._instance_name) for stream in self.streams ): @@ -158,6 +158,21 @@ class ReplicationStreamer: all_streams = list(all_streams) random.shuffle(all_streams) + if self.command_handler.should_announce_positions(): + # We need to send out POSITIONs for all streams, usually + # because a worker has reconnected. + self.command_handler.will_announce_positions() + + for stream in all_streams: + self.command_handler.send_command( + PositionCommand( + stream.NAME, + self._instance_name, + stream.last_token, + stream.last_token, + ) + ) + for stream in all_streams: if stream.last_token == stream.current_token( self._instance_name diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 58a44029aa..cc34dfb322 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
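The streams/_base.py hunk below lets monotonic streams veto redundant POSITION commands. The check itself is a one-liner; a standalone sketch with invented values:

def can_discard_position(current_token: int, new_token: int) -> bool:
    # Streams backed by an ID generator never move backwards, so a POSITION
    # whose new_token is at or behind the token we already know about for
    # that instance carries no new information.
    return new_token <= current_token


assert can_discard_position(current_token=10, new_token=9)       # stale, drop it
assert not can_discard_position(current_token=10, new_token=11)  # genuinely new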
@@ -144,6 +144,16 @@ class Stream: """ raise NotImplementedError() + def can_discard_position( + self, instance_name: str, prev_token: int, new_token: int + ) -> bool: + """Whether or not a position command for this stream can be discarded. + + Useful for streams that can never go backwards and where we already know + the stream ID for the instance has advanced. + """ + return False + def discard_updates_and_advance(self) -> None: """Called when the stream should advance but the updates would be discarded, e.g. when there are no currently connected workers. @@ -221,6 +231,14 @@ class _StreamFromIdGen(Stream): def minimal_local_current_token(self) -> Token: return self._stream_id_gen.get_minimal_local_current_token() + def can_discard_position( + self, instance_name: str, prev_token: int, new_token: int + ) -> bool: + # These streams can't go backwards, so we know we can ignore any + # positions where the tokens are from before the current token. + + return new_token <= self.current_token(instance_name) + def current_token_without_instance( current_token: Callable[[], int] diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 9bd0d764f8..91edfd45d7 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -88,6 +88,7 @@ from synapse.rest.admin.users import ( UserByThreePid, UserMembershipRestServlet, UserRegisterServlet, + UserReplaceMasterCrossSigningKeyRestServlet, UserRestServletV2, UsersRestServletV2, UserTokenRestServlet, @@ -292,6 +293,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ListDestinationsRestServlet(hs).register(http_server) RoomMessagesRestServlet(hs).register(http_server) RoomTimestampToEventRestServlet(hs).register(http_server) + UserReplaceMasterCrossSigningKeyRestServlet(hs).register(http_server) UserByExternalId(hs).register(http_server) UserByThreePid(hs).register(http_server) diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index b7637dff0b..8cf5268854 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -17,6 +17,8 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, Optional, Tuple +import attr + from synapse.api.constants import Direction from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer @@ -418,7 +420,7 @@ class UserMediaRestServlet(RestServlet): start, limit, user_id, order_by, direction ) - ret = {"media": media, "total": total} + ret = {"media": [attr.asdict(m) for m in media], "total": total} if (start + limit) < total: ret["next_token"] = start + len(media) @@ -477,7 +479,7 @@ class UserMediaRestServlet(RestServlet): ) deleted_media, total = await self.media_repository.delete_local_media_ids( - [row["media_id"] for row in media] + [m.media_id for m in media] ) return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index ffce92d45e..f3e06d3da3 100644
--- a/synapse/rest/admin/registration_tokens.py
+++ b/synapse/rest/admin/registration_tokens.py
@@ -77,7 +77,18 @@ class ListRegistrationTokensRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) valid = parse_boolean(request, "valid") token_list = await self.store.get_registration_tokens(valid) - return HTTPStatus.OK, {"registration_tokens": token_list} + return HTTPStatus.OK, { + "registration_tokens": [ + { + "token": t[0], + "uses_allowed": t[1], + "pending": t[2], + "completed": t[3], + "expiry_time": t[4], + } + for t in token_list + ] + } class NewRegistrationTokenRestServlet(RestServlet): diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 0659f22a89..7e40bea8aa 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -16,6 +16,8 @@ from http import HTTPStatus from typing import TYPE_CHECKING, List, Optional, Tuple, cast from urllib import parse as urlparse +import attr + from synapse.api.constants import Direction, EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.filtering import Filter @@ -306,10 +308,13 @@ class RoomRestServlet(RestServlet): raise NotFoundError("Room not found") members = await self.store.get_users_in_room(room_id) - ret["joined_local_devices"] = await self.store.count_devices_by_users(members) - ret["forgotten"] = await self.store.is_locally_forgotten_room(room_id) + result = attr.asdict(ret) + result["joined_local_devices"] = await self.store.count_devices_by_users( + members + ) + result["forgotten"] = await self.store.is_locally_forgotten_room(room_id) - return HTTPStatus.OK, ret + return HTTPStatus.OK, result async def on_DELETE( self, request: SynapseRequest, room_id: str @@ -408,8 +413,8 @@ class RoomMembersRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - ret = await self.store.get_room(room_id) - if not ret: + room = await self.store.get_room(room_id) + if not room: raise NotFoundError("Room not found") members = await self.store.get_users_in_room(room_id) @@ -437,8 +442,8 @@ class RoomStateRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - ret = await self.store.get_room(room_id) - if not ret: + room = await self.store.get_room(room_id) + if not room: raise NotFoundError("Room not found") event_ids = await self._storage_controllers.state.get_current_state_ids(room_id) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 7fe16130e7..9900498fbe 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
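The users.py hunk below adds the admin endpoint registered in __init__.py above. A hedged request sketch, assuming the usual /_synapse/admin/v1 prefix produced by admin_patterns; the homeserver URL, token and user ID are placeholders:

import json
import urllib.request

homeserver = "https://synapse.example.com"      # placeholder
admin_token = "ADMIN_ACCESS_TOKEN"              # placeholder
user_id = "@alice:example.com"                  # placeholder

req = urllib.request.Request(
    f"{homeserver}/_synapse/admin/v1/users/{user_id}"
    "/_allow_cross_signing_replacement_without_uia",
    data=b"{}",
    method="POST",
    headers={"Authorization": f"Bearer {admin_token}"},
)
with urllib.request.urlopen(req) as resp:
    body = json.load(resp)

# Per the servlet below, the reply gives the end of the 10 minute window in
# which the user's master cross-signing key may be replaced without UIA.
print(body["updatable_without_uia_before_ms"])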
@@ -18,6 +18,8 @@ import secrets from http import HTTPStatus from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +import attr + from synapse.api.constants import Direction, UserTypes from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( @@ -161,11 +163,13 @@ class UsersRestServletV2(RestServlet): ) # If support for MSC3866 is not enabled, don't show the approval flag. + filter = None if not self._msc3866_enabled: - for user in users: - del user["approved"] - ret = {"users": users, "total": total} + def _filter(a: attr.Attribute) -> bool: + return a.name != "approved" + + ret = {"users": [attr.asdict(u, filter=filter) for u in users], "total": total} if (start + limit) < total: ret["next_token"] = str(start + len(users)) @@ -1266,6 +1270,46 @@ class AccountDataRestServlet(RestServlet): } +class UserReplaceMasterCrossSigningKeyRestServlet(RestServlet): + """Allow a given user to replace their master cross-signing key without UIA. + + This replacement is permitted for a limited period (currently 10 minutes). + + While this is exposed via the admin API, this is intended for use by the + Matrix Authentication Service rather than server admins. + """ + + PATTERNS = admin_patterns( + "/users/(?P<user_id>[^/]*)/_allow_cross_signing_replacement_without_uia" + ) + REPLACEMENT_PERIOD_MS = 10 * 60 * 1000 # 10 minutes + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastores().main + + async def on_POST( + self, + request: SynapseRequest, + user_id: str, + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + if user_id is None: + raise NotFoundError("User not found") + + timestamp = ( + await self._store.allow_master_cross_signing_key_replacement_without_uia( + user_id, self.REPLACEMENT_PERIOD_MS + ) + ) + + if timestamp is None: + raise NotFoundError("User has no master cross-signing key") + + return HTTPStatus.OK, {"updatable_without_uia_before_ms": timestamp} + + class UserByExternalId(RestServlet): """Find a user based on an external ID from an auth provider""" diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 641390cb30..0c0e82627d 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -299,19 +299,16 @@ class DeactivateAccountRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) - # allow ASes to deactivate their own users - if requester.app_service: - await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), body.erase, requester + # allow ASes to deactivate their own users: + # ASes don't need user-interactive auth + if not requester.app_service: + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body.dict(exclude_unset=True), + "deactivate your account", ) - return 200, {} - await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body.dict(exclude_unset=True), - "deactivate your account", - ) result = await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), body.erase, requester, id_server=body.id_server ) diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index 82944ca711..3534c3c259 100644
--- a/synapse/rest/client/directory.py
+++ b/synapse/rest/client/directory.py
@@ -147,7 +147,7 @@ class ClientDirectoryListServer(RestServlet): if room is None: raise NotFoundError("Unknown room") - return 200, {"visibility": "public" if room["is_public"] else "private"} + return 200, {"visibility": "public" if room[0] else "private"} class PutBody(RequestBodyModel): visibility: Literal["public", "private"] = "public" diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 70b8be1aa2..add8045439 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -376,9 +376,10 @@ class SigningKeyUploadServlet(RestServlet): user_id = requester.user.to_string() body = parse_json_object_from_request(request) - is_cross_signing_setup = ( - await self.e2e_keys_handler.is_cross_signing_set_up_for_user(user_id) - ) + ( + is_cross_signing_setup, + master_key_updatable_without_uia, + ) = await self.e2e_keys_handler.check_cross_signing_setup(user_id) # Before MSC3967 we required UIA both when setting up cross signing for the # first time and when resetting the device signing key. With MSC3967 we only @@ -386,9 +387,14 @@ class SigningKeyUploadServlet(RestServlet): # time. Because there is no UIA in MSC3861, for now we throw an error if the # user tries to reset the device signing key when MSC3861 is enabled, but allow # first-time setup. + # + # XXX: We now have a get-out clause by which MAS can temporarily mark the master + # key as replaceable. It should do its own equivalent of user interactive auth + # before doing so. if self.hs.config.experimental.msc3861.enabled: - # There is no way to reset the device signing key with MSC3861 - if is_cross_signing_setup: + # The auth service has to explicitly mark the master key as replaceable + # without UIA to reset the device signing key with MSC3861. + if is_cross_signing_setup and not master_key_updatable_without_uia: raise SynapseError( HTTPStatus.NOT_IMPLEMENTED, "Resetting cross signing keys is not yet supported with MSC3861", diff --git a/synapse/rest/media/create_resource.py b/synapse/rest/media/create_resource.py new file mode 100644
index 0000000000..994afdf13c
--- /dev/null
+++ b/synapse/rest/media/create_resource.py
@@ -0,0 +1,83 @@ +# Copyright 2023 Beeper Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +from typing import TYPE_CHECKING + +from synapse.api.errors import LimitExceededError +from synapse.api.ratelimiting import Ratelimiter +from synapse.http.server import respond_with_json +from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest + +if TYPE_CHECKING: + from synapse.media.media_repository import MediaRepository + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class CreateResource(RestServlet): + PATTERNS = [re.compile("/_matrix/media/v1/create")] + + def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): + super().__init__() + + self.media_repo = media_repo + self.clock = hs.get_clock() + self.auth = hs.get_auth() + self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads + + # A rate limiter for creating new media IDs. + self._create_media_rate_limiter = Ratelimiter( + store=hs.get_datastores().main, + clock=self.clock, + cfg=hs.config.ratelimiting.rc_media_create, + ) + + async def on_POST(self, request: SynapseRequest) -> None: + requester = await self.auth.get_user_by_req(request) + + # If the create media requests for the user are over the limit, drop them. + await self._create_media_rate_limiter.ratelimit(requester) + + ( + reached_pending_limit, + first_expiration_ts, + ) = await self.media_repo.reached_pending_media_limit(requester.user) + if reached_pending_limit: + raise LimitExceededError( + limiter_name="max_pending_media_uploads", + retry_after_ms=first_expiration_ts - self.clock.time_msec(), + ) + + content_uri, unused_expires_at = await self.media_repo.create_media_id( + requester.user + ) + + logger.info( + "Created Media URI %r that if unused will expire at %d", + content_uri, + unused_expires_at, + ) + respond_with_json( + request, + 200, + { + "content_uri": content_uri, + "unused_expires_at": unused_expires_at, + }, + send_cors=True, + ) diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py
index 65b9ff52fa..60cd87548c 100644
--- a/synapse/rest/media/download_resource.py
+++ b/synapse/rest/media/download_resource.py
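The download_resource.py hunk below adds a timeout_ms query parameter so a client can wait for media that is still being uploaded via the new async flow; the value is clamped to a server-side maximum. A tiny sketch of the clamp with placeholder constants (the real ones live in synapse.media._base):

DEFAULT_MAX_TIMEOUT_MS = 20_000            # placeholder values; see
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS = 60_000    # synapse.media._base for the real ones


def effective_timeout_ms(requested_ms=None):
    # Absent parameter -> default; anything above the cap is clamped down.
    if requested_ms is None:
        requested_ms = DEFAULT_MAX_TIMEOUT_MS
    return min(requested_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)


assert effective_timeout_ms() == 20_000
assert effective_timeout_ms(5_000) == 5_000
assert effective_timeout_ms(600_000) == 60_000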
@@ -17,9 +17,13 @@ import re from typing import TYPE_CHECKING, Optional from synapse.http.server import set_corp_headers, set_cors_headers -from synapse.http.servlet import RestServlet, parse_boolean +from synapse.http.servlet import RestServlet, parse_boolean, parse_integer from synapse.http.site import SynapseRequest -from synapse.media._base import respond_404 +from synapse.media._base import ( + DEFAULT_MAX_TIMEOUT_MS, + MAXIMUM_ALLOWED_MAX_TIMEOUT_MS, + respond_404, +) from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: @@ -65,12 +69,16 @@ class DownloadResource(RestServlet): ) # Limited non-standard form of CSP for IE11 request.setHeader(b"X-Content-Security-Policy", b"sandbox;") - request.setHeader( - b"Referrer-Policy", - b"no-referrer", + request.setHeader(b"Referrer-Policy", b"no-referrer") + max_timeout_ms = parse_integer( + request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS ) + max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS) + if self._is_mine_server_name(server_name): - await self.media_repo.get_local_media(request, media_id, file_name) + await self.media_repo.get_local_media( + request, media_id, file_name, max_timeout_ms + ) else: allow_remote = parse_boolean(request, "allow_remote", default=True) if not allow_remote: @@ -83,5 +91,5 @@ class DownloadResource(RestServlet): return await self.media_repo.get_remote_media( - request, server_name, media_id, file_name + request, server_name, media_id, file_name, max_timeout_ms ) diff --git a/synapse/rest/media/media_repository_resource.py b/synapse/rest/media/media_repository_resource.py
index 2089bb1029..ca65116b84 100644
--- a/synapse/rest/media/media_repository_resource.py
+++ b/synapse/rest/media/media_repository_resource.py
@@ -18,10 +18,11 @@ from synapse.config._base import ConfigError from synapse.http.server import HttpServer, JsonResource from .config_resource import MediaConfigResource +from .create_resource import CreateResource from .download_resource import DownloadResource from .preview_url_resource import PreviewUrlResource from .thumbnail_resource import ThumbnailResource -from .upload_resource import UploadResource +from .upload_resource import AsyncUploadServlet, UploadServlet if TYPE_CHECKING: from synapse.server import HomeServer @@ -91,8 +92,9 @@ class MediaRepositoryResource(JsonResource): # Note that many of these should not exist as v1 endpoints, but empirically # a lot of traffic still goes to them. - - UploadResource(hs, media_repo).register(http_server) + CreateResource(hs, media_repo).register(http_server) + UploadServlet(hs, media_repo).register(http_server) + AsyncUploadServlet(hs, media_repo).register(http_server) DownloadResource(hs, media_repo).register(http_server) ThumbnailResource(hs, media_repo, media_repo.media_storage).register( http_server diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
index 85b6bdbe72..681f2a5a27 100644
--- a/synapse/rest/media/thumbnail_resource.py
+++ b/synapse/rest/media/thumbnail_resource.py
@@ -23,6 +23,8 @@ from synapse.http.server import respond_with_json, set_corp_headers, set_cors_he from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.media._base import ( + DEFAULT_MAX_TIMEOUT_MS, + MAXIMUM_ALLOWED_MAX_TIMEOUT_MS, FileInfo, ThumbnailInfo, respond_404, @@ -75,15 +77,19 @@ class ThumbnailResource(RestServlet): method = parse_string(request, "method", "scale") # TODO Parse the Accept header to get an prioritised list of thumbnail types. m_type = "image/png" + max_timeout_ms = parse_integer( + request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS + ) + max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS) if self._is_mine_server_name(server_name): if self.dynamic_thumbnails: await self._select_or_generate_local_thumbnail( - request, media_id, width, height, method, m_type + request, media_id, width, height, method, m_type, max_timeout_ms ) else: await self._respond_local_thumbnail( - request, media_id, width, height, method, m_type + request, media_id, width, height, method, m_type, max_timeout_ms ) self.media_repo.mark_recently_accessed(None, media_id) else: @@ -95,14 +101,21 @@ class ThumbnailResource(RestServlet): respond_404(request) return - if self.dynamic_thumbnails: - await self._select_or_generate_remote_thumbnail( - request, server_name, media_id, width, height, method, m_type - ) - else: - await self._respond_remote_thumbnail( - request, server_name, media_id, width, height, method, m_type - ) + remote_resp_function = ( + self._select_or_generate_remote_thumbnail + if self.dynamic_thumbnails + else self._respond_remote_thumbnail + ) + await remote_resp_function( + request, + server_name, + media_id, + width, + height, + method, + m_type, + max_timeout_ms, + ) self.media_repo.mark_recently_accessed(server_name, media_id) async def _respond_local_thumbnail( @@ -113,15 +126,12 @@ class ThumbnailResource(RestServlet): height: int, method: str, m_type: str, + max_timeout_ms: int, ) -> None: - media_info = await self.store.get_local_media(media_id) - + media_info = await self.media_repo.get_local_media_info( + request, media_id, max_timeout_ms + ) if not media_info: - respond_404(request) - return - if media_info["quarantined_by"]: - logger.info("Media is quarantined") - respond_404(request) return thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) @@ -134,7 +144,7 @@ class ThumbnailResource(RestServlet): thumbnail_infos, media_id, media_id, - url_cache=bool(media_info["url_cache"]), + url_cache=bool(media_info.url_cache), server_name=None, ) @@ -146,15 +156,13 @@ class ThumbnailResource(RestServlet): desired_height: int, desired_method: str, desired_type: str, + max_timeout_ms: int, ) -> None: - media_info = await self.store.get_local_media(media_id) + media_info = await self.media_repo.get_local_media_info( + request, media_id, max_timeout_ms + ) if not media_info: - respond_404(request) - return - if media_info["quarantined_by"]: - logger.info("Media is quarantined") - respond_404(request) return thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) @@ -168,7 +176,7 @@ class ThumbnailResource(RestServlet): file_info = FileInfo( server_name=None, file_id=media_id, - url_cache=media_info["url_cache"], + url_cache=bool(media_info.url_cache), thumbnail=info, ) @@ -188,7 +196,7 @@ class ThumbnailResource(RestServlet): desired_height, desired_method, desired_type, - url_cache=bool(media_info["url_cache"]), + 
url_cache=bool(media_info.url_cache), ) if file_path: @@ -206,14 +214,20 @@ class ThumbnailResource(RestServlet): desired_height: int, desired_method: str, desired_type: str, + max_timeout_ms: int, ) -> None: - media_info = await self.media_repo.get_remote_media_info(server_name, media_id) + media_info = await self.media_repo.get_remote_media_info( + server_name, media_id, max_timeout_ms + ) + if not media_info: + respond_404(request) + return thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) - file_id = media_info["filesystem_id"] + file_id = media_info.filesystem_id for info in thumbnail_infos: t_w = info.width == desired_width @@ -224,7 +238,7 @@ class ThumbnailResource(RestServlet): if t_w and t_h and t_method and t_type: file_info = FileInfo( server_name=server_name, - file_id=media_info["filesystem_id"], + file_id=file_id, thumbnail=info, ) @@ -263,11 +277,16 @@ class ThumbnailResource(RestServlet): height: int, method: str, m_type: str, + max_timeout_ms: int, ) -> None: # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead of # downloading the remote file and generating our own thumbnails. - media_info = await self.media_repo.get_remote_media_info(server_name, media_id) + media_info = await self.media_repo.get_remote_media_info( + server_name, media_id, max_timeout_ms + ) + if not media_info: + return thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id @@ -280,7 +299,7 @@ class ThumbnailResource(RestServlet): m_type, thumbnail_infos, media_id, - media_info["filesystem_id"], + media_info.filesystem_id, url_cache=False, server_name=server_name, ) diff --git a/synapse/rest/media/upload_resource.py b/synapse/rest/media/upload_resource.py
index 949326d85d..62d3e228a8 100644
--- a/synapse/rest/media/upload_resource.py
+++ b/synapse/rest/media/upload_resource.py
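Together with CreateResource above, the upload_resource.py hunk below completes the asynchronous upload flow: reserve a media ID first, PUT the bytes to it later. A hedged client-side sketch; the homeserver, server name, token and payload are placeholders:

import json
import urllib.request

homeserver = "https://synapse.example.com"   # placeholder
server_name = "example.com"                  # placeholder
headers = {"Authorization": "Bearer USER_ACCESS_TOKEN"}

# 1. Reserve a media ID (CreateResource above); the response also carries
#    unused_expires_at, after which an unused ID is reclaimed.
create = urllib.request.Request(
    f"{homeserver}/_matrix/media/v1/create", data=b"", method="POST", headers=headers
)
with urllib.request.urlopen(create) as resp:
    content_uri = json.load(resp)["content_uri"]   # mxc://example.com/<media_id>
media_id = content_uri.rsplit("/", 1)[1]

# 2. Upload the actual content to the reserved ID (AsyncUploadServlet below).
upload = urllib.request.Request(
    f"{homeserver}/_matrix/media/v3/upload/{server_name}/{media_id}",
    data=b"hello world",
    method="PUT",
    headers={**headers, "Content-Type": "text/plain"},
)
urllib.request.urlopen(upload)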
@@ -15,7 +15,7 @@ import logging import re -from typing import IO, TYPE_CHECKING, Dict, List, Optional +from typing import IO, TYPE_CHECKING, Dict, List, Optional, Tuple from synapse.api.errors import Codes, SynapseError from synapse.http.server import respond_with_json @@ -29,23 +29,24 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# The name of the lock to use when uploading media. +_UPLOAD_MEDIA_LOCK_NAME = "upload_media" -class UploadResource(RestServlet): - PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload")] +class BaseUploadServlet(RestServlet): def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): super().__init__() self.media_repo = media_repo self.filepaths = media_repo.filepaths self.store = hs.get_datastores().main - self.clock = hs.get_clock() + self.server_name = hs.hostname self.auth = hs.get_auth() self.max_upload_size = hs.config.media.max_upload_size - self.clock = hs.get_clock() - async def on_POST(self, request: SynapseRequest) -> None: - requester = await self.auth.get_user_by_req(request) + def _get_file_metadata( + self, request: SynapseRequest + ) -> Tuple[int, Optional[str], str]: raw_content_length = request.getHeader("Content-Length") if raw_content_length is None: raise SynapseError(msg="Request must specify a Content-Length", code=400) @@ -88,6 +89,16 @@ class UploadResource(RestServlet): # disposition = headers.getRawHeaders(b"Content-Disposition")[0] # TODO(markjh): parse content-dispostion + return content_length, upload_name, media_type + + +class UploadServlet(BaseUploadServlet): + PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload$")] + + async def on_POST(self, request: SynapseRequest) -> None: + requester = await self.auth.get_user_by_req(request) + content_length, upload_name, media_type = self._get_file_metadata(request) + try: content: IO = request.content # type: ignore content_uri = await self.media_repo.create_content( @@ -103,3 +114,53 @@ class UploadResource(RestServlet): respond_with_json( request, 200, {"content_uri": str(content_uri)}, send_cors=True ) + + +class AsyncUploadServlet(BaseUploadServlet): + PATTERNS = [ + re.compile( + "/_matrix/media/v3/upload/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$" + ) + ] + + async def on_PUT( + self, request: SynapseRequest, server_name: str, media_id: str + ) -> None: + requester = await self.auth.get_user_by_req(request) + + if server_name != self.server_name: + raise SynapseError( + 404, + "Non-local server name specified", + errcode=Codes.NOT_FOUND, + ) + + lock = await self.store.try_acquire_lock(_UPLOAD_MEDIA_LOCK_NAME, media_id) + if not lock: + raise SynapseError( + 409, + "Media ID cannot be overwritten", + errcode=Codes.CANNOT_OVERWRITE_MEDIA, + ) + + async with lock: + await self.media_repo.verify_can_upload(media_id, requester.user) + content_length, upload_name, media_type = self._get_file_metadata(request) + + try: + content: IO = request.content # type: ignore + await self.media_repo.update_content( + media_id, + media_type, + upload_name, + content, + content_length, + requester.user, + ) + except SpamMediaException: + # For uploading of media we want to respond with a 400, instead of + # the default 404, as that would just be confusing. + raise SynapseError(400, "Bad content") + + logger.info("Uploaded content for media ID %r", media_id) + respond_with_json(request, 200, {}, send_cors=True) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 12829d3d7d..62fbd05534 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -28,6 +28,7 @@ from typing import ( Sequence, Tuple, Type, + cast, ) import attr @@ -48,7 +49,11 @@ else: if TYPE_CHECKING: from synapse.server import HomeServer - from synapse.storage.database import DatabasePool, LoggingTransaction + from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + ) logger = logging.getLogger(__name__) @@ -488,14 +493,14 @@ class BackgroundUpdater: True if we have finished running all the background updates, otherwise False """ - def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]: + def get_background_updates_txn(txn: Cursor) -> List[Tuple[str, Optional[str]]]: txn.execute( """ SELECT update_name, depends_on FROM background_updates ORDER BY ordering, update_name """ ) - return self.db_pool.cursor_to_dict(txn) + return cast(List[Tuple[str, Optional[str]]], txn.fetchall()) if not self._current_background_update: all_pending_updates = await self.db_pool.runInteraction( @@ -507,14 +512,13 @@ class BackgroundUpdater: return True # find the first update which isn't dependent on another one in the queue. - pending = {update["update_name"] for update in all_pending_updates} - for upd in all_pending_updates: - depends_on = upd["depends_on"] + pending = {update_name for update_name, depends_on in all_pending_updates} + for update_name, depends_on in all_pending_updates: if not depends_on or depends_on not in pending: break logger.info( "Not starting on bg update %s until %s is done", - upd["update_name"], + update_name, depends_on, ) else: @@ -524,7 +528,7 @@ class BackgroundUpdater: "another: dependency cycle?" ) - self._current_background_update = upd["update_name"] + self._current_background_update = update_name # We have a background update to run, otherwise we would have returned # early. @@ -746,10 +750,10 @@ class BackgroundUpdater: The named index will be dropped upon completion of the new index. """ - def create_index_psql(conn: Connection) -> None: + def create_index_psql(conn: "LoggingDatabaseConnection") -> None: conn.rollback() # postgres insists on autocommit for the index - conn.set_session(autocommit=True) # type: ignore + conn.engine.attempt_to_set_autocommit(conn.conn, True) try: c = conn.cursor() @@ -793,9 +797,9 @@ class BackgroundUpdater: undo_timeout_sql = f"SET statement_timeout = {default_timeout}" conn.cursor().execute(undo_timeout_sql) - conn.set_session(autocommit=False) # type: ignore + conn.engine.attempt_to_set_autocommit(conn.conn, False) - def create_index_sqlite(conn: Connection) -> None: + def create_index_sqlite(conn: "LoggingDatabaseConnection") -> None: # Sqlite doesn't support concurrent creation of indexes. # # We assume that sqlite doesn't give us invalid indices; however @@ -825,7 +829,9 @@ class BackgroundUpdater: c.execute(sql) if isinstance(self.db_pool.engine, engines.PostgresEngine): - runner: Optional[Callable[[Connection], None]] = create_index_psql + runner: Optional[ + Callable[[LoggingDatabaseConnection], None] + ] = create_index_psql elif psql_only: runner = None else: diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index f39ae2d635..1529c86cc5 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -542,13 +542,15 @@ class EventsPersistenceStorageController: return await res.get_state(self._state_controller, StateFilter.all()) async def _persist_event_batch( - self, _room_id: str, task: _PersistEventsTask + self, room_id: str, task: _PersistEventsTask ) -> Dict[str, str]: """Callback for the _event_persist_queue Calculates the change to current state and forward extremities, and persists the given events and with those updates. + Assumes that we are only persisting events for one room at a time. + Returns: A dictionary of event ID to event ID we didn't persist as we already had another event persisted with the same TXN ID. @@ -594,140 +596,23 @@ class EventsPersistenceStorageController: # We can't easily parallelize these since different chunks # might contain the same event. :( - # NB: Assumes that we are only persisting events for one room - # at a time. - - # map room_id->set[event_ids] giving the new forward - # extremities in each room - new_forward_extremities: Dict[str, Set[str]] = {} - - # map room_id->(to_delete, to_insert) where to_delete is a list - # of type/state keys to remove from current state, and to_insert - # is a map (type,key)->event_id giving the state delta in each - # room - state_delta_for_room: Dict[str, DeltaState] = {} + new_forward_extremities = None + state_delta_for_room = None if not backfilled: with Measure(self._clock, "_calculate_state_and_extrem"): - # Work out the new "current state" for each room. + # Work out the new "current state" for the room. # We do this by working out what the new extremities are and then # calculating the state from that. - events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {} - for event, context in chunk: - events_by_room.setdefault(event.room_id, []).append( - (event, context) - ) - - for room_id, ev_ctx_rm in events_by_room.items(): - latest_event_ids = ( - await self.main_store.get_latest_event_ids_in_room(room_id) - ) - new_latest_event_ids = await self._calculate_new_extremities( - room_id, ev_ctx_rm, latest_event_ids - ) - - if new_latest_event_ids == latest_event_ids: - # No change in extremities, so no change in state - continue - - # there should always be at least one forward extremity. - # (except during the initial persistence of the send_join - # results, in which case there will be no existing - # extremities, so we'll `continue` above and skip this bit.) - assert new_latest_event_ids, "No forward extremities left!" - - new_forward_extremities[room_id] = new_latest_event_ids - - len_1 = ( - len(latest_event_ids) == 1 - and len(new_latest_event_ids) == 1 - ) - if len_1: - all_single_prev_not_state = all( - len(event.prev_event_ids()) == 1 - and not event.is_state() - for event, ctx in ev_ctx_rm - ) - # Don't bother calculating state if they're just - # a long chain of single ancestor non-state events. - if all_single_prev_not_state: - continue - - state_delta_counter.inc() - if len(new_latest_event_ids) == 1: - state_delta_single_event_counter.inc() - - # This is a fairly handwavey check to see if we could - # have guessed what the delta would have been when - # processing one of these events. - # What we're interested in is if the latest extremities - # were the same when we created the event as they are - # now. When this server creates a new event (as opposed - # to receiving it over federation) it will use the - # forward extremities as the prev_events, so we can - # guess this by looking at the prev_events and checking - # if they match the current forward extremities. 
- for ev, _ in ev_ctx_rm: - prev_event_ids = set(ev.prev_event_ids()) - if latest_event_ids == prev_event_ids: - state_delta_reuse_delta_counter.inc() - break - - logger.debug("Calculating state delta for room %s", room_id) - with Measure( - self._clock, "persist_events.get_new_state_after_events" - ): - res = await self._get_new_state_after_events( - room_id, - ev_ctx_rm, - latest_event_ids, - new_latest_event_ids, - ) - current_state, delta_ids, new_latest_event_ids = res - - # there should always be at least one forward extremity. - # (except during the initial persistence of the send_join - # results, in which case there will be no existing - # extremities, so we'll `continue` above and skip this bit.) - assert new_latest_event_ids, "No forward extremities left!" - - new_forward_extremities[room_id] = new_latest_event_ids - - # If either are not None then there has been a change, - # and we need to work out the delta (or use that - # given) - delta = None - if delta_ids is not None: - # If there is a delta we know that we've - # only added or replaced state, never - # removed keys entirely. - delta = DeltaState([], delta_ids) - elif current_state is not None: - with Measure( - self._clock, "persist_events.calculate_state_delta" - ): - delta = await self._calculate_state_delta( - room_id, current_state - ) - - if delta: - # If we have a change of state then lets check - # whether we're actually still a member of the room, - # or if our last user left. If we're no longer in - # the room then we delete the current state and - # extremities. - is_still_joined = await self._is_server_still_joined( - room_id, - ev_ctx_rm, - delta, - ) - if not is_still_joined: - logger.info("Server no longer in room %s", room_id) - delta.no_longer_in_room = True - - state_delta_for_room[room_id] = delta + ( + new_forward_extremities, + state_delta_for_room, + ) = await self._calculate_new_forward_extremities_and_state_delta( + room_id, chunk + ) await self.persist_events_store._persist_events_and_state_updates( + room_id, chunk, state_delta_for_room=state_delta_for_room, new_forward_extremities=new_forward_extremities, @@ -737,6 +622,117 @@ class EventsPersistenceStorageController: return replaced_events + async def _calculate_new_forward_extremities_and_state_delta( + self, room_id: str, ev_ctx_rm: List[Tuple[EventBase, EventContext]] + ) -> Tuple[Optional[Set[str]], Optional[DeltaState]]: + """Calculates the new forward extremities and state delta for a room + given events to persist. + + Assumes that we are only persisting events for one room at a time. + + Returns: + A tuple of: + A set of str giving the new forward extremities the room + + The state delta for the room. + """ + + latest_event_ids = await self.main_store.get_latest_event_ids_in_room(room_id) + new_latest_event_ids = await self._calculate_new_extremities( + room_id, ev_ctx_rm, latest_event_ids + ) + + if new_latest_event_ids == latest_event_ids: + # No change in extremities, so no change in state + return (None, None) + + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" 
+ + new_forward_extremities = new_latest_event_ids + + len_1 = len(latest_event_ids) == 1 and len(new_latest_event_ids) == 1 + if len_1: + all_single_prev_not_state = all( + len(event.prev_event_ids()) == 1 and not event.is_state() + for event, ctx in ev_ctx_rm + ) + # Don't bother calculating state if they're just + # a long chain of single ancestor non-state events. + if all_single_prev_not_state: + return (new_forward_extremities, None) + + state_delta_counter.inc() + if len(new_latest_event_ids) == 1: + state_delta_single_event_counter.inc() + + # This is a fairly handwavey check to see if we could + # have guessed what the delta would have been when + # processing one of these events. + # What we're interested in is if the latest extremities + # were the same when we created the event as they are + # now. When this server creates a new event (as opposed + # to receiving it over federation) it will use the + # forward extremities as the prev_events, so we can + # guess this by looking at the prev_events and checking + # if they match the current forward extremities. + for ev, _ in ev_ctx_rm: + prev_event_ids = set(ev.prev_event_ids()) + if latest_event_ids == prev_event_ids: + state_delta_reuse_delta_counter.inc() + break + + logger.debug("Calculating state delta for room %s", room_id) + with Measure(self._clock, "persist_events.get_new_state_after_events"): + res = await self._get_new_state_after_events( + room_id, + ev_ctx_rm, + latest_event_ids, + new_latest_event_ids, + ) + current_state, delta_ids, new_latest_event_ids = res + + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" + + new_forward_extremities = new_latest_event_ids + + # If either are not None then there has been a change, + # and we need to work out the delta (or use that + # given) + delta = None + if delta_ids is not None: + # If there is a delta we know that we've + # only added or replaced state, never + # removed keys entirely. + delta = DeltaState([], delta_ids) + elif current_state is not None: + with Measure(self._clock, "persist_events.calculate_state_delta"): + delta = await self._calculate_state_delta(room_id, current_state) + + if delta: + # If we have a change of state then lets check + # whether we're actually still a member of the room, + # or if our last user left. If we're no longer in + # the room then we delete the current state and + # extremities. + is_still_joined = await self._is_server_still_joined( + room_id, + ev_ctx_rm, + delta, + ) + if not is_still_joined: + logger.info("Server no longer in room %s", room_id) + delta.no_longer_in_room = True + + return (new_forward_extremities, delta) + async def _calculate_new_extremities( self, room_id: str, diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a4e7048368..eb34de4df5 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
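The database.py hunk below removes cursor_to_dict and switches simple_select_one (and related helpers) to return plain tuples, so callers unpack positionally in retcols order rather than indexing by column name. A small sketch of the new calling convention; the table and columns are illustrative:

async def user_is_admin(db_pool, user_id: str) -> bool:
    # Previously: row["admin"] on a dict keyed by column name.
    # Now: a tuple whose order matches retcols.
    row = await db_pool.simple_select_one(
        table="users",
        keyvalues={"name": user_id},
        retcols=("admin", "deactivated"),
        allow_none=True,
    )
    if row is None:
        return False
    admin, _deactivated = row
    return bool(admin)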
@@ -18,7 +18,6 @@ import logging import time import types from collections import defaultdict -from sys import intern from time import monotonic as monotonic_time from typing import ( TYPE_CHECKING, @@ -1042,20 +1041,6 @@ class DatabasePool: self._db_pool.runWithConnection(inner_func, *args, **kwargs) ) - @staticmethod - def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]: - """Converts a SQL cursor into an list of dicts. - - Args: - cursor: The DBAPI cursor which has executed a query. - Returns: - A list of dicts where the key is the column header. - """ - assert cursor.description is not None, "cursor.description was None" - col_headers = [intern(str(column[0])) for column in cursor.description] - results = [dict(zip(col_headers, row)) for row in cursor] - return results - async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]: """Runs a single query for a result set. @@ -1131,8 +1116,8 @@ class DatabasePool: def simple_insert_many_txn( txn: LoggingTransaction, table: str, - keys: Collection[str], - values: Iterable[Iterable[Any]], + keys: Sequence[str], + values: Collection[Iterable[Any]], ) -> None: """Executes an INSERT query on the named table. @@ -1145,6 +1130,9 @@ class DatabasePool: keys: list of column names values: for each row, a list of values in the same order as `keys` """ + # If there's nothing to insert, then skip executing the query. + if not values: + return if isinstance(txn.database_engine, PostgresEngine): # We use `execute_values` as it can be a lot faster than `execute_batch`, @@ -1416,12 +1404,12 @@ class DatabasePool: allvalues.update(values) latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) - sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % ( + sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %sDO %s" % ( table, ", ".join(k for k in allvalues), ", ".join("?" for _ in allvalues), ", ".join(k for k in keyvalues), - f"WHERE {where_clause}" if where_clause else "", + f"WHERE {where_clause} " if where_clause else "", latter, ) txn.execute(sql, list(allvalues.values())) @@ -1470,7 +1458,7 @@ class DatabasePool: key_names: Collection[str], key_values: Collection[Iterable[Any]], value_names: Collection[str], - value_values: Iterable[Iterable[Any]], + value_values: Collection[Iterable[Any]], ) -> None: """ Upsert, many times. @@ -1483,6 +1471,19 @@ class DatabasePool: value_values: A list of each row's value column values. Ignored if value_names is empty. """ + # If there's nothing to upsert, then skip executing the query. + if not key_values: + return + + # No value columns, therefore make a blank list so that the following + # zip() works correctly. + if not value_names: + value_values = [() for x in range(len(key_values))] + elif len(value_values) != len(key_values): + raise ValueError( + f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number." + ) + if table not in self._unsafe_to_upsert_tables: return self.simple_upsert_many_txn_native_upsert( txn, table, key_names, key_values, value_names, value_values @@ -1517,10 +1518,6 @@ class DatabasePool: value_values: A list of each row's value column values. Ignored if value_names is empty. """ - # No value columns, therefore make a blank list so that the following - # zip() works correctly. - if not value_names: - value_values = [() for x in range(len(key_values))] # Lock the table just once, to prevent it being done once per row. 
# Note that, according to Postgres' documentation, once obtained, @@ -1558,10 +1555,7 @@ class DatabasePool: allnames.extend(value_names) if not value_names: - # No value columns, therefore make a blank list so that the - # following zip() works correctly. latter = "NOTHING" - value_values = [() for x in range(len(key_values))] else: latter = "UPDATE SET " + ", ".join( k + "=EXCLUDED." + k for k in value_names @@ -1603,7 +1597,7 @@ class DatabasePool: retcols: Collection[str], allow_none: Literal[False] = False, desc: str = "simple_select_one", - ) -> Dict[str, Any]: + ) -> Tuple[Any, ...]: ... @overload @@ -1614,7 +1608,7 @@ class DatabasePool: retcols: Collection[str], allow_none: Literal[True] = True, desc: str = "simple_select_one", - ) -> Optional[Dict[str, Any]]: + ) -> Optional[Tuple[Any, ...]]: ... async def simple_select_one( @@ -1624,7 +1618,7 @@ class DatabasePool: retcols: Collection[str], allow_none: bool = False, desc: str = "simple_select_one", - ) -> Optional[Dict[str, Any]]: + ) -> Optional[Tuple[Any, ...]]: """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. @@ -1925,6 +1919,7 @@ class DatabasePool: Returns: The results as a list of tuples. """ + # If there's nothing to select, then skip executing the query. if not iterable: return [] @@ -2059,13 +2054,13 @@ class DatabasePool: raise ValueError( f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number." ) + # If there is nothing to update, then skip executing the query. + if not key_values: + return # List of tuples of (value values, then key values) # (This matches the order needed for the query) - args = [tuple(x) + tuple(y) for x, y in zip(value_values, key_values)] - - for ks, vs in zip(key_values, value_values): - args.append(tuple(vs) + tuple(ks)) + args = [tuple(vv) + tuple(kv) for vv, kv in zip(value_values, key_values)] # 'col1 = ?, col2 = ?, ...' set_clause = ", ".join(f"{n} = ?" for n in value_names) @@ -2077,9 +2072,7 @@ class DatabasePool: where_clause = "" # UPDATE mytable SET col1 = ?, col2 = ? WHERE col3 = ? AND col4 = ? - sql = f""" - UPDATE {table} SET {set_clause} {where_clause} - """ + sql = f"UPDATE {table} SET {set_clause} {where_clause}" txn.execute_batch(sql, args) @@ -2134,7 +2127,7 @@ class DatabasePool: keyvalues: Dict[str, Any], retcols: Collection[str], allow_none: bool = False, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[Tuple[Any, ...]]: select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table) if keyvalues: @@ -2152,7 +2145,7 @@ class DatabasePool: if txn.rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) - return dict(zip(retcols, row)) + return row async def simple_delete_one( self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one" @@ -2295,11 +2288,10 @@ class DatabasePool: Returns: Number rows deleted """ + # If there's nothing to delete, then skip executing the query. if not values: return 0 - sql = "DELETE FROM %s" % table - clause, values = make_in_list_sql_clause(txn.database_engine, column, values) clauses = [clause] @@ -2307,8 +2299,7 @@ class DatabasePool: clauses.append("%s = ?" % (key,)) values.append(value) - if clauses: - sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) + sql = "DELETE FROM %s WHERE %s" % (table, " AND ".join(clauses)) txn.execute(sql, values) return txn.rowcount diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 7aa24ccf21..b57e260fe0 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -45,7 +45,7 @@ class Databases(Generic[DataStoreT]): """ databases: List[DatabasePool] - main: "DataStore" # FIXME: #11165: actually an instance of `main_store_class` + main: "DataStore" # FIXME: https://github.com/matrix-org/synapse/issues/11165: actually an instance of `main_store_class` state: StateGroupDataStore persist_events: Optional[PersistEventsStore] diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 840d725114..89f4077351 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
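The main/__init__.py hunk below replaces cursor_to_dict rows with a typed attrs class, which the admin API (users.py hunk earlier) serialises via attr.asdict. A cut-down illustration of that round trip; the class here keeps only three of the fields:

from typing import Any

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class UserPaginateResponse:
    # Trimmed copy of the class added below, for illustration only.
    name: str
    admin: bool
    approved: bool


row = UserPaginateResponse(name="@alice:example.com", admin=False, approved=True)


def _hide_approved(a: attr.Attribute, _value: Any) -> bool:
    # Mirrors how users.py drops the "approved" field when MSC3866 is off.
    return a.name != "approved"


print(attr.asdict(row, filter=_hide_approved))
# {'name': '@alice:example.com', 'admin': False}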
@@ -17,6 +17,8 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast +import attr + from synapse.api.constants import Direction from synapse.config.homeserver import HomeServerConfig from synapse.storage._base import make_in_list_sql_clause @@ -28,7 +30,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.types import Cursor -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import get_domain_from_id from .account_data import AccountDataStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore @@ -82,6 +84,25 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class UserPaginateResponse: + """This is very similar to UserInfo, but not quite the same.""" + + name: str + user_type: Optional[str] + is_guest: bool + admin: bool + deactivated: bool + shadow_banned: bool + displayname: Optional[str] + avatar_url: Optional[str] + creation_ts: Optional[int] + approved: bool + erased: bool + last_seen_ts: int + locked: bool + + class DataStore( EventsBackgroundUpdatesStore, ExperimentalFeaturesStore, @@ -156,7 +177,7 @@ class DataStore( approved: bool = True, not_user_types: Optional[List[str]] = None, locked: bool = False, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[UserPaginateResponse], int]: """Function to retrieve a paginated list of users from users list. This will return a json list of users and the total number of users matching the filter criteria. @@ -182,7 +203,7 @@ class DataStore( def get_users_paginate_txn( txn: LoggingTransaction, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[UserPaginateResponse], int]: filters = [] args: list = [] @@ -282,13 +303,24 @@ class DataStore( """ args += [limit, start] txn.execute(sql, args) - users = self.db_pool.cursor_to_dict(txn) - - # some of those boolean values are returned as integers when we're on SQLite - columns_to_boolify = ["erased"] - for user in users: - for column in columns_to_boolify: - user[column] = bool(user[column]) + users = [ + UserPaginateResponse( + name=row[0], + user_type=row[1], + is_guest=bool(row[2]), + admin=bool(row[3]), + deactivated=bool(row[4]), + shadow_banned=bool(row[5]), + displayname=row[6], + avatar_url=row[7], + creation_ts=row[8], + approved=bool(row[9]), + erased=bool(row[10]), + last_seen_ts=row[11], + locked=bool(row[12]), + ) + for row in txn + ] return users, count diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index d7482a1f4e..07f9b65af3 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -747,8 +747,16 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) # Invalidate the cache for any ignored users which were added or removed. - for ignored_user_id in previously_ignored_users ^ currently_ignored_users: - self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) + self._invalidate_cache_and_stream_bulk( + txn, + self.ignored_by, + [ + (ignored_user_id,) + for ignored_user_id in ( + previously_ignored_users ^ currently_ignored_users + ) + ], + ) self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) async def remove_account_data_for_user( @@ -824,10 +832,14 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) # Invalidate the cache for ignored users which were removed. - for ignored_user_id in previously_ignored_users: - self._invalidate_cache_and_stream( - txn, self.ignored_by, (ignored_user_id,) - ) + self._invalidate_cache_and_stream_bulk( + txn, + self.ignored_by, + [ + (ignored_user_id,) + for ignored_user_id in previously_ignored_users + ], + ) # Invalidate for this user the cache tracking ignored users. self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 4d0470ffd9..d7232f566b 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -483,6 +483,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore): txn.call_after(cache_func.invalidate, keys) self._send_invalidation_to_replication(txn, cache_func.__name__, keys) + def _invalidate_cache_and_stream_bulk( + self, + txn: LoggingTransaction, + cache_func: CachedFunction, + key_tuples: Collection[Tuple[Any, ...]], + ) -> None: + """A bulk version of _invalidate_cache_and_stream. + + Locally invalidate every key-tuple in `key_tuples`, then emit invalidations + for each key-tuple over replication. + + This implementation is more efficient than a loop which repeatedly calls the + non-bulk version. + """ + if not key_tuples: + return + + for keys in key_tuples: + txn.call_after(cache_func.invalidate, keys) + + self._send_invalidation_to_replication_bulk( + txn, cache_func.__name__, key_tuples + ) + def _invalidate_all_cache_and_stream( self, txn: LoggingTransaction, cache_func: CachedFunction ) -> None: @@ -564,10 +588,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if isinstance(self.database_engine, PostgresEngine): assert self._cache_id_gen is not None - # get_next() returns a context manager which is designed to wrap - # the transaction. However, we want to only get an ID when we want - # to use it, here, so we need to call __enter__ manually, and have - # __exit__ called after the transaction finishes. stream_id = self._cache_id_gen.get_next_txn(txn) txn.call_after(self.hs.get_notifier().on_new_replication_data) @@ -586,6 +606,53 @@ class CacheInvalidationWorkerStore(SQLBaseStore): }, ) + def _send_invalidation_to_replication_bulk( + self, + txn: LoggingTransaction, + cache_name: str, + key_tuples: Collection[Tuple[Any, ...]], + ) -> None: + """Announce the invalidation of multiple (but not all) cache entries. + + This is more efficient than repeated calls to the non-bulk version. It should + NOT be used to invalidating the entire cache: use + `_send_invalidation_to_replication` with keys=None. + + Note that this does *not* invalidate the cache locally. + + Args: + txn + cache_name + key_tuples: Key-tuples to invalidate. Assumed to be non-empty. + """ + if isinstance(self.database_engine, PostgresEngine): + assert self._cache_id_gen is not None + + stream_ids = self._cache_id_gen.get_next_mult_txn(txn, len(key_tuples)) + ts = self._clock.time_msec() + txn.call_after(self.hs.get_notifier().on_new_replication_data) + self.db_pool.simple_insert_many_txn( + txn, + table="cache_invalidation_stream_by_instance", + keys=( + "stream_id", + "instance_name", + "cache_func", + "keys", + "invalidation_ts", + ), + values=[ + # We convert key_tuples to a list here because psycopg2 serialises + # lists as pq arrrays, but serialises tuples as "composite types". + # (We need an array because the `keys` column has type `[]text`.) + # See: + # https://www.psycopg.org/docs/usage.html#adapt-list + # https://www.psycopg.org/docs/usage.html#adapt-tuple + (stream_id, self._instance_name, cache_name, list(key_tuple), ts) + for stream_id, key_tuple in zip(stream_ids, key_tuples) + ], + ) + def get_cache_stream_token_for_writer(self, instance_name: str) -> int: if self._cache_id_gen: return self._cache_id_gen.get_current_token_for_writer(instance_name) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 3e7425d4a6..02dddd1da4 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py
@@ -450,14 +450,12 @@ class DeviceInboxWorkerStore(SQLBaseStore): user_id: str, device_id: Optional[str], up_to_stream_id: int, - limit: Optional[int] = None, ) -> int: """ Args: user_id: The recipient user_id. device_id: The recipient device_id. up_to_stream_id: Where to delete messages up to. - limit: maximum number of messages to delete Returns: The number of messages deleted. @@ -478,32 +476,22 @@ class DeviceInboxWorkerStore(SQLBaseStore): log_kv({"message": "No changes in cache since last check"}) return 0 - def delete_messages_for_device_txn(txn: LoggingTransaction) -> int: - limit_statement = "" if limit is None else f"LIMIT {limit}" - sql = f""" - DELETE FROM device_inbox WHERE user_id = ? AND device_id = ? AND stream_id <= ( - SELECT MAX(stream_id) FROM ( - SELECT stream_id FROM device_inbox - WHERE user_id = ? AND device_id = ? AND stream_id <= ? - ORDER BY stream_id - {limit_statement} - ) AS q1 - ) - """ - txn.execute(sql, (user_id, device_id, user_id, device_id, up_to_stream_id)) - return txn.rowcount - - count = await self.db_pool.runInteraction( - "delete_messages_for_device", delete_messages_for_device_txn - ) + from_stream_id = None + count = 0 + while True: + from_stream_id, loop_count = await self.delete_messages_for_device_between( + user_id, + device_id, + from_stream_id=from_stream_id, + to_stream_id=up_to_stream_id, + limit=1000, + ) + count += loop_count + if from_stream_id is None: + break log_kv({"message": f"deleted {count} messages for device", "count": count}) - # In this case we don't know if we hit the limit or the delete is complete - # so let's not update the cache. - if count == limit: - return count - # Update the cache, ensuring that we only ever increase the value updated_last_deleted_stream_id = self._last_device_delete_cache.get( (user_id, device_id), 0 @@ -515,6 +503,74 @@ class DeviceInboxWorkerStore(SQLBaseStore): return count @trace + async def delete_messages_for_device_between( + self, + user_id: str, + device_id: Optional[str], + from_stream_id: Optional[int], + to_stream_id: int, + limit: int, + ) -> Tuple[Optional[int], int]: + """Delete N device messages between the stream IDs, returning the + highest stream ID deleted (or None if all messages in the range have + been deleted) and the number of messages deleted. + + This is more efficient than `delete_messages_for_device` when calling in + a loop to batch delete messages. + """ + + # Keeping track of a lower bound of stream ID where we've deleted + # everything below makes the queries much faster. Otherwise, every time + # we scan for rows to delete we'd re-scan across all the rows that have + # previously deleted (until the next table VACUUM). + + if from_stream_id is None: + # Minimum device stream ID is 1. + from_stream_id = 0 + + def delete_messages_for_device_between_txn( + txn: LoggingTransaction, + ) -> Tuple[Optional[int], int]: + txn.execute( + """ + SELECT MAX(stream_id) FROM ( + SELECT stream_id FROM device_inbox + WHERE user_id = ? AND device_id = ? + AND ? < stream_id AND stream_id <= ? + ORDER BY stream_id + LIMIT ? + ) AS d + """, + (user_id, device_id, from_stream_id, to_stream_id, limit), + ) + row = txn.fetchone() + if row is None or row[0] is None: + return None, 0 + + (max_stream_id,) = row + + txn.execute( + """ + DELETE FROM device_inbox + WHERE user_id = ? AND device_id = ? + AND ? < stream_id AND stream_id <= ? 
+ """, + (user_id, device_id, from_stream_id, max_stream_id), + ) + + num_deleted = txn.rowcount + if num_deleted < limit: + return None, num_deleted + + return max_stream_id, num_deleted + + return await self.db_pool.runInteraction( + "delete_messages_for_device_between", + delete_messages_for_device_between_txn, + db_autocommit=True, # We don't need to run in a transaction + ) + + @trace async def get_new_device_msgs_for_remote( self, destination: str, last_stream_id: int, current_stream_id: int, limit: int ) -> Tuple[List[JsonDict], int]: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 49edbb9e06..775abbac79 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -255,33 +255,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): A dict containing the device information, or `None` if the device does not exist. """ - return await self.db_pool.simple_select_one( - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, - retcols=("user_id", "device_id", "display_name"), - desc="get_device", - allow_none=True, - ) - - async def get_device_opt( - self, user_id: str, device_id: str - ) -> Optional[Dict[str, Any]]: - """Retrieve a device. Only returns devices that are not marked as - hidden. - - Args: - user_id: The ID of the user which owns the device - device_id: The ID of the device to retrieve - Returns: - A dict containing the device information, or None if the device does not exist. - """ - return await self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), desc="get_device", allow_none=True, ) + if row is None: + return None + return {"user_id": row[0], "device_id": row[1], "display_name": row[2]} async def get_devices_by_user( self, user_id: str @@ -703,7 +686,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): key_names=("destination", "user_id"), key_values=[(destination, user_id) for user_id, _ in rows], value_names=("stream_id",), - value_values=((stream_id,) for _, stream_id in rows), + value_values=[(stream_id,) for _, stream_id in rows], ) # Delete all sent outbound pokes @@ -1221,9 +1204,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): retcols=["device_id", "device_data"], allow_none=True, ) - return ( - (row["device_id"], json_decoder.decode(row["device_data"])) if row else None - ) + return (row[0], json_decoder.decode(row[1])) if row else None def _store_dehydrated_device_txn( self, @@ -1620,7 +1601,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): # # For each duplicate, we delete all the existing rows and put one back. - KEY_COLS = ["stream_id", "destination", "user_id", "device_id"] last_row = progress.get( "last_row", {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, @@ -1628,44 +1608,62 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): def _txn(txn: LoggingTransaction) -> int: clause, args = make_tuple_comparison_clause( - [(x, last_row[x]) for x in KEY_COLS] + [ + ("stream_id", last_row["stream_id"]), + ("destination", last_row["destination"]), + ("user_id", last_row["user_id"]), + ("device_id", last_row["device_id"]), + ] ) - sql = """ + sql = f""" SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts FROM device_lists_outbound_pokes - WHERE %s - GROUP BY %s + WHERE {clause} + GROUP BY stream_id, destination, user_id, device_id HAVING count(*) > 1 - ORDER BY %s + ORDER BY stream_id, destination, user_id, device_id LIMIT ? 
- """ % ( - clause, # WHERE - ",".join(KEY_COLS), # GROUP BY - ",".join(KEY_COLS), # ORDER BY - ) + """ txn.execute(sql, args + [batch_size]) - rows = self.db_pool.cursor_to_dict(txn) + rows = txn.fetchall() - row = None - for row in rows: + stream_id, destination, user_id, device_id = None, None, None, None + for stream_id, destination, user_id, device_id, _ in rows: self.db_pool.simple_delete_txn( txn, "device_lists_outbound_pokes", - {x: row[x] for x in KEY_COLS}, + { + "stream_id": stream_id, + "destination": destination, + "user_id": user_id, + "device_id": device_id, + }, ) - row["sent"] = False self.db_pool.simple_insert_txn( txn, "device_lists_outbound_pokes", - row, + { + "stream_id": stream_id, + "destination": destination, + "user_id": user_id, + "device_id": device_id, + "sent": False, + }, ) - if row: + if rows: self.db_pool.updates._background_update_progress_txn( txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, - {"last_row": row}, + { + "last_row": { + "stream_id": stream_id, + "destination": destination, + "user_id": user_id, + "device_id": device_id, + } + }, ) return len(rows) @@ -2309,13 +2307,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): `FALSE` have not been converted. """ - row = await self.db_pool.simple_select_one( - table="device_lists_changes_converted_stream_position", - keyvalues={}, - retcols=["stream_id", "room_id"], - desc="get_device_change_last_converted_pos", + return cast( + Tuple[int, str], + await self.db_pool.simple_select_one( + table="device_lists_changes_converted_stream_position", + keyvalues={}, + retcols=["stream_id", "room_id"], + desc="get_device_change_last_converted_pos", + ), ) - return row["stream_id"], row["room_id"] async def set_device_change_last_converted_pos( self, diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index ad904a26a6..fae23c3407 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -506,19 +506,26 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): # it isn't there. raise StoreError(404, "No backup with that version exists") - result = self.db_pool.simple_select_one_txn( - txn, - table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, - retcols=("version", "algorithm", "auth_data", "etag"), - allow_none=False, + row = cast( + Tuple[int, str, str, Optional[int]], + self.db_pool.simple_select_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={ + "user_id": user_id, + "version": this_version, + "deleted": 0, + }, + retcols=("version", "algorithm", "auth_data", "etag"), + allow_none=False, + ), ) - assert result is not None # see comment on `simple_select_one_txn` - result["auth_data"] = db_to_json(result["auth_data"]) - result["version"] = str(result["version"]) - if result["etag"] is None: - result["etag"] = 0 - return result + return { + "auth_data": db_to_json(row[2]), + "version": str(row[0]), + "algorithm": row[1], + "etag": 0 if row[3] is None else row[3], + } return await self.db_pool.runInteraction( "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4f96ac25c7..9e98729330 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1237,13 +1237,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker for user_id, device_id, algorithm, key_id, key_json in claimed_keys: device_results = results.setdefault(user_id, {}).setdefault(device_id, {}) device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json) - - if (user_id, device_id) in seen_user_device: - continue seen_user_device.add((user_id, device_id)) - self._invalidate_cache_and_stream( - txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) - ) + + self._invalidate_cache_and_stream_bulk( + txn, self.get_e2e_unused_fallback_key_types, seen_user_device + ) return results @@ -1268,9 +1266,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker if row is None: continue - key_id = row["key_id"] - key_json = row["key_json"] - used = row["used"] + key_id, key_json, used = row # Mark fallback key as used if not already. if not used and mark_as_used: @@ -1376,17 +1372,62 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list) ) - seen_user_device: Set[Tuple[str, str]] = set() - for user_id, device_id, _, _, _ in otk_rows: - if (user_id, device_id) in seen_user_device: - continue - seen_user_device.add((user_id, device_id)) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) + seen_user_device = { + (user_id, device_id) for user_id, device_id, _, _, _ in otk_rows + } + self._invalidate_cache_and_stream_bulk( + txn, + self.count_e2e_one_time_keys, + seen_user_device, + ) return otk_rows + async def get_master_cross_signing_key_updatable_before( + self, user_id: str + ) -> Tuple[bool, Optional[int]]: + """Get time before which a master cross-signing key may be replaced without UIA. + + (UIA means "User-Interactive Auth".) + + There are three cases to distinguish: + (1) No master cross-signing key. + (2) The key exists, but there is no replace-without-UI timestamp in the DB. + (3) The key exists, and has such a timestamp recorded. + + Returns: a 2-tuple of: + - a boolean: is there a master cross-signing key already? + - an optional timestamp, directly taken from the DB. + + In terms of the cases above, these are: + (1) (False, None). + (2) (True, None). + (3) (True, <timestamp in ms>). + + """ + + def impl(txn: LoggingTransaction) -> Tuple[bool, Optional[int]]: + # We want to distinguish between three cases: + txn.execute( + """ + SELECT updatable_without_uia_before_ms + FROM e2e_cross_signing_keys + WHERE user_id = ? AND keytype = 'master' + ORDER BY stream_id DESC + LIMIT 1 + """, + (user_id,), + ) + row = cast(Optional[Tuple[Optional[int]]], txn.fetchone()) + if row is None: + return False, None + return True, row[0] + + return await self.db_pool.runInteraction( + "e2e_cross_signing_keys", + impl, + ) + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def __init__( @@ -1634,3 +1675,42 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ], desc="add_e2e_signing_key", ) + + async def allow_master_cross_signing_key_replacement_without_uia( + self, user_id: str, duration_ms: int + ) -> Optional[int]: + """Mark this user's latest master key as being replaceable without UIA. + + Said replacement will only be permitted for a short time after calling this + function. That time period is controlled by the duration argument. + + Returns: + None, if there is no such key. 
+ Otherwise, the timestamp before which replacement is allowed without UIA. + """ + timestamp = self._clock.time_msec() + duration_ms + + def impl(txn: LoggingTransaction) -> Optional[int]: + txn.execute( + """ + UPDATE e2e_cross_signing_keys + SET updatable_without_uia_before_ms = ? + WHERE stream_id = ( + SELECT stream_id + FROM e2e_cross_signing_keys + WHERE user_id = ? AND keytype = 'master' + ORDER BY stream_id DESC + LIMIT 1 + ) + """, + (timestamp, user_id), + ) + if txn.rowcount == 0: + return None + + return timestamp + + return await self.db_pool.runInteraction( + "allow_master_cross_signing_key_replacement_without_uia", + impl, + ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
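One plausible way a caller could interpret the three cases returned by `get_master_cross_signing_key_updatable_before` (the caller below is hypothetical and not part of the patch):

    from typing import Optional, Tuple

    def may_replace_without_uia(result: Tuple[bool, Optional[int]], now_ms: int) -> bool:
        exists, updatable_before_ms = result
        if not exists:
            # (1) No master cross-signing key yet, so there is nothing to replace.
            return False
        if updatable_before_ms is None:
            # (2) The key exists but no replacement window was ever granted.
            return False
        # (3) Replacement is allowed only while the recorded window is still open.
        return now_ms < updatable_before_ms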
index f1b0991503..7e992ca4a2 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -193,7 +193,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Check if we have indexed the room so we can use the chain cover # algorithm. room = await self.get_room(room_id) # type: ignore[attr-defined] - if room["has_auth_chain_index"]: + # If the room has an auth chain index. + if room[1]: try: return await self.db_pool.runInteraction( "get_auth_chain_ids_chains", @@ -411,7 +412,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Check if we have indexed the room so we can use the chain cover # algorithm. room = await self.get_room(room_id) # type: ignore[attr-defined] - if room["has_auth_chain_index"]: + # If the room has an auth chain index. + if room[1]: try: return await self.db_pool.runInteraction( "get_auth_chain_difference_chains", @@ -1437,24 +1439,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) if event_lookup_result is not None: + event_type, depth, stream_ordering = event_lookup_result logger.debug( "_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s", room_id, seed_event_id, - event_lookup_result["depth"], - event_lookup_result["stream_ordering"], - event_lookup_result["type"], + depth, + stream_ordering, + event_type, ) - if event_lookup_result["depth"]: - queue.put( - ( - -event_lookup_result["depth"], - -event_lookup_result["stream_ordering"], - seed_event_id, - event_lookup_result["type"], - ) - ) + if depth: + queue.put((-depth, -stream_ordering, seed_event_id, event_type)) while not queue.empty() and len(event_id_results) < limit: try: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3c1492e3ad..5207cc0f4e 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -79,7 +79,7 @@ class DeltaState: Attributes: to_delete: List of type/state_keys to delete from current state to_insert: Map of state to upsert into current state - no_longer_in_room: The server is not longer in the room, so the room + no_longer_in_room: The server is no longer in the room, so the room should e.g. be removed from `current_state_events` table. """ @@ -131,22 +131,25 @@ class PersistEventsStore: @trace async def _persist_events_and_state_updates( self, + room_id: str, events_and_contexts: List[Tuple[EventBase, EventContext]], *, - state_delta_for_room: Dict[str, DeltaState], - new_forward_extremities: Dict[str, Set[str]], + state_delta_for_room: Optional[DeltaState], + new_forward_extremities: Optional[Set[str]], use_negative_stream_ordering: bool = False, inhibit_local_membership_updates: bool = False, ) -> None: """Persist a set of events alongside updates to the current state and - forward extremities tables. + forward extremities tables. + + Assumes that we are only persisting events for one room at a time. Args: + room_id: events_and_contexts: - state_delta_for_room: Map from room_id to the delta to apply to - room state - new_forward_extremities: Map from room_id to set of event IDs - that are the new forward extremities of the room. + state_delta_for_room: The delta to apply to the room state + new_forward_extremities: A set of event IDs that are the new forward + extremities of the room. use_negative_stream_ordering: Whether to start stream_ordering on the negative side and decrement. This should be set as True for backfilled events because backfilled events get a negative @@ -196,6 +199,7 @@ class PersistEventsStore: await self.db_pool.runInteraction( "persist_events", self._persist_events_txn, + room_id=room_id, events_and_contexts=events_and_contexts, inhibit_local_membership_updates=inhibit_local_membership_updates, state_delta_for_room=state_delta_for_room, @@ -221,9 +225,9 @@ class PersistEventsStore: event_counter.labels(event.type, origin_type, origin_entity).inc() - for room_id, latest_event_ids in new_forward_extremities.items(): + if new_forward_extremities: self.store.get_latest_event_ids_in_room.prefill( - (room_id,), frozenset(latest_event_ids) + (room_id,), frozenset(new_forward_extremities) ) async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]: @@ -336,10 +340,11 @@ class PersistEventsStore: self, txn: LoggingTransaction, *, + room_id: str, events_and_contexts: List[Tuple[EventBase, EventContext]], inhibit_local_membership_updates: bool, - state_delta_for_room: Dict[str, DeltaState], - new_forward_extremities: Dict[str, Set[str]], + state_delta_for_room: Optional[DeltaState], + new_forward_extremities: Optional[Set[str]], ) -> None: """Insert some number of room events into the necessary database tables. @@ -347,8 +352,11 @@ class PersistEventsStore: and the rejections table. Things reading from those table will need to check whether the event was rejected. + Assumes that we are only persisting events for one room at a time. + Args: txn + room_id: The room the events are from events_and_contexts: events to persist inhibit_local_membership_updates: Stop the local_current_membership from being updated by these events. This should be set to True @@ -357,10 +365,9 @@ class PersistEventsStore: delete_existing True to purge existing table rows for the events from the database. This is useful when retrying due to IntegrityError. - state_delta_for_room: The current-state delta for each room. 
- new_forward_extremities: The new forward extremities for each room. - For each room, a list of the event ids which are the forward - extremities. + state_delta_for_room: The current-state delta for the room. + new_forward_extremities: The new forward extremities for the room: + a set of the event ids which are the forward extremities. Raises: PartialStateConflictError: if attempting to persist a partial state event in @@ -376,14 +383,13 @@ class PersistEventsStore: # # Annoyingly SQLite doesn't support row level locking. if isinstance(self.database_engine, PostgresEngine): - for room_id in {e.room_id for e, _ in events_and_contexts}: - txn.execute( - "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE", - (room_id,), - ) - row = txn.fetchone() - if row is None: - raise Exception(f"Room does not exist {room_id}") + txn.execute( + "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE", + (room_id,), + ) + row = txn.fetchone() + if row is None: + raise Exception(f"Room does not exist {room_id}") # stream orderings should have been assigned by now assert min_stream_order @@ -419,7 +425,9 @@ class PersistEventsStore: events_and_contexts ) - self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts) + self._update_room_depths_txn( + txn, room_id, events_and_contexts=events_and_contexts + ) # _update_outliers_txn filters out any events which have already been # persisted, and returns the filtered list. @@ -432,11 +440,13 @@ class PersistEventsStore: self._store_event_txn(txn, events_and_contexts=events_and_contexts) - self._update_forward_extremities_txn( - txn, - new_forward_extremities=new_forward_extremities, - max_stream_order=max_stream_order, - ) + if new_forward_extremities: + self._update_forward_extremities_txn( + txn, + room_id, + new_forward_extremities=new_forward_extremities, + max_stream_order=max_stream_order, + ) self._persist_transaction_ids_txn(txn, events_and_contexts) @@ -464,7 +474,10 @@ class PersistEventsStore: # We call this last as it assumes we've inserted the events into # room_memberships, where applicable. # NB: This function invalidates all state related caches - self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) + if state_delta_for_room: + self._update_current_state_txn( + txn, room_id, state_delta_for_room, min_stream_order + ) def _persist_event_auth_chain_txn( self, @@ -1026,74 +1039,75 @@ class PersistEventsStore: await self.db_pool.runInteraction( "update_current_state", self._update_current_state_txn, - state_delta_by_room={room_id: state_delta}, + room_id, + delta_state=state_delta, stream_id=stream_ordering, ) def _update_current_state_txn( self, txn: LoggingTransaction, - state_delta_by_room: Dict[str, DeltaState], + room_id: str, + delta_state: DeltaState, stream_id: int, ) -> None: - for room_id, delta_state in state_delta_by_room.items(): - to_delete = delta_state.to_delete - to_insert = delta_state.to_insert - - # Figure out the changes of membership to invalidate the - # `get_rooms_for_user` cache. - # We find out which membership events we may have deleted - # and which we have added, then we invalidate the caches for all - # those users. - members_changed = { - state_key - for ev_type, state_key in itertools.chain(to_delete, to_insert) - if ev_type == EventTypes.Member - } + to_delete = delta_state.to_delete + to_insert = delta_state.to_insert + + # Figure out the changes of membership to invalidate the + # `get_rooms_for_user` cache. 
+ # We find out which membership events we may have deleted + # and which we have added, then we invalidate the caches for all + # those users. + members_changed = { + state_key + for ev_type, state_key in itertools.chain(to_delete, to_insert) + if ev_type == EventTypes.Member + } - if delta_state.no_longer_in_room: - # Server is no longer in the room so we delete the room from - # current_state_events, being careful we've already updated the - # rooms.room_version column (which gets populated in a - # background task). - self._upsert_room_version_txn(txn, room_id) + if delta_state.no_longer_in_room: + # Server is no longer in the room so we delete the room from + # current_state_events, being careful we've already updated the + # rooms.room_version column (which gets populated in a + # background task). + self._upsert_room_version_txn(txn, room_id) - # Before deleting we populate the current_state_delta_stream - # so that async background tasks get told what happened. - sql = """ + # Before deleting we populate the current_state_delta_stream + # so that async background tasks get told what happened. + sql = """ INSERT INTO current_state_delta_stream (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id) SELECT ?, ?, room_id, type, state_key, null, event_id FROM current_state_events WHERE room_id = ? """ - txn.execute(sql, (stream_id, self._instance_name, room_id)) + txn.execute(sql, (stream_id, self._instance_name, room_id)) - # We also want to invalidate the membership caches for users - # that were in the room. - users_in_room = self.store.get_users_in_room_txn(txn, room_id) - members_changed.update(users_in_room) + # We also want to invalidate the membership caches for users + # that were in the room. + users_in_room = self.store.get_users_in_room_txn(txn, room_id) + members_changed.update(users_in_room) - self.db_pool.simple_delete_txn( - txn, - table="current_state_events", - keyvalues={"room_id": room_id}, - ) - else: - # We're still in the room, so we update the current state as normal. + self.db_pool.simple_delete_txn( + txn, + table="current_state_events", + keyvalues={"room_id": room_id}, + ) + else: + # We're still in the room, so we update the current state as normal. - # First we add entries to the current_state_delta_stream. We - # do this before updating the current_state_events table so - # that we can use it to calculate the `prev_event_id`. (This - # allows us to not have to pull out the existing state - # unnecessarily). - # - # The stream_id for the update is chosen to be the minimum of the stream_ids - # for the batch of the events that we are persisting; that means we do not - # end up in a situation where workers see events before the - # current_state_delta updates. - # - sql = """ + # First we add entries to the current_state_delta_stream. We + # do this before updating the current_state_events table so + # that we can use it to calculate the `prev_event_id`. (This + # allows us to not have to pull out the existing state + # unnecessarily). + # + # The stream_id for the update is chosen to be the minimum of the stream_ids + # for the batch of the events that we are persisting; that means we do not + # end up in a situation where workers see events before the + # current_state_delta updates. + # + sql = """ INSERT INTO current_state_delta_stream (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id) SELECT ?, ?, ?, ?, ?, ?, ( @@ -1101,39 +1115,39 @@ class PersistEventsStore: WHERE room_id = ? AND type = ? 
AND state_key = ? ) """ - txn.execute_batch( - sql, + txn.execute_batch( + sql, + ( ( - ( - stream_id, - self._instance_name, - room_id, - etype, - state_key, - to_insert.get((etype, state_key)), - room_id, - etype, - state_key, - ) - for etype, state_key in itertools.chain(to_delete, to_insert) - ), - ) - # Now we actually update the current_state_events table + stream_id, + self._instance_name, + room_id, + etype, + state_key, + to_insert.get((etype, state_key)), + room_id, + etype, + state_key, + ) + for etype, state_key in itertools.chain(to_delete, to_insert) + ), + ) + # Now we actually update the current_state_events table - txn.execute_batch( - "DELETE FROM current_state_events" - " WHERE room_id = ? AND type = ? AND state_key = ?", - ( - (room_id, etype, state_key) - for etype, state_key in itertools.chain(to_delete, to_insert) - ), - ) + txn.execute_batch( + "DELETE FROM current_state_events" + " WHERE room_id = ? AND type = ? AND state_key = ?", + ( + (room_id, etype, state_key) + for etype, state_key in itertools.chain(to_delete, to_insert) + ), + ) - # We include the membership in the current state table, hence we do - # a lookup when we insert. This assumes that all events have already - # been inserted into room_memberships. - txn.execute_batch( - """INSERT INTO current_state_events + # We include the membership in the current state table, hence we do + # a lookup when we insert. This assumes that all events have already + # been inserted into room_memberships. + txn.execute_batch( + """INSERT INTO current_state_events (room_id, type, state_key, event_id, membership, event_stream_ordering) VALUES ( ?, ?, ?, ?, @@ -1141,34 +1155,34 @@ class PersistEventsStore: (SELECT stream_ordering FROM events WHERE event_id = ?) ) """, - [ - (room_id, key[0], key[1], ev_id, ev_id, ev_id) - for key, ev_id in to_insert.items() - ], - ) + [ + (room_id, key[0], key[1], ev_id, ev_id, ev_id) + for key, ev_id in to_insert.items() + ], + ) - # We now update `local_current_membership`. We do this regardless - # of whether we're still in the room or not to handle the case where - # e.g. we just got banned (where we need to record that fact here). - - # Note: Do we really want to delete rows here (that we do not - # subsequently reinsert below)? While technically correct it means - # we have no record of the fact the user *was* a member of the - # room but got, say, state reset out of it. - if to_delete or to_insert: - txn.execute_batch( - "DELETE FROM local_current_membership" - " WHERE room_id = ? AND user_id = ?", - ( - (room_id, state_key) - for etype, state_key in itertools.chain(to_delete, to_insert) - if etype == EventTypes.Member and self.is_mine_id(state_key) - ), - ) + # We now update `local_current_membership`. We do this regardless + # of whether we're still in the room or not to handle the case where + # e.g. we just got banned (where we need to record that fact here). + + # Note: Do we really want to delete rows here (that we do not + # subsequently reinsert below)? While technically correct it means + # we have no record of the fact the user *was* a member of the + # room but got, say, state reset out of it. + if to_delete or to_insert: + txn.execute_batch( + "DELETE FROM local_current_membership" + " WHERE room_id = ? 
AND user_id = ?", + ( + (room_id, state_key) + for etype, state_key in itertools.chain(to_delete, to_insert) + if etype == EventTypes.Member and self.is_mine_id(state_key) + ), + ) - if to_insert: - txn.execute_batch( - """INSERT INTO local_current_membership + if to_insert: + txn.execute_batch( + """INSERT INTO local_current_membership (room_id, user_id, event_id, membership, event_stream_ordering) VALUES ( ?, ?, ?, @@ -1176,29 +1190,27 @@ class PersistEventsStore: (SELECT stream_ordering FROM events WHERE event_id = ?) ) """, - [ - (room_id, key[1], ev_id, ev_id, ev_id) - for key, ev_id in to_insert.items() - if key[0] == EventTypes.Member and self.is_mine_id(key[1]) - ], - ) - - txn.call_after( - self.store._curr_state_delta_stream_cache.entity_has_changed, - room_id, - stream_id, + [ + (room_id, key[1], ev_id, ev_id, ev_id) + for key, ev_id in to_insert.items() + if key[0] == EventTypes.Member and self.is_mine_id(key[1]) + ], ) - # Invalidate the various caches - self.store._invalidate_state_caches_and_stream( - txn, room_id, members_changed - ) + txn.call_after( + self.store._curr_state_delta_stream_cache.entity_has_changed, + room_id, + stream_id, + ) - # Check if any of the remote membership changes requires us to - # unsubscribe from their device lists. - self.store.handle_potentially_left_users_txn( - txn, {m for m in members_changed if not self.hs.is_mine_id(m)} - ) + # Invalidate the various caches + self.store._invalidate_state_caches_and_stream(txn, room_id, members_changed) + + # Check if any of the remote membership changes requires us to + # unsubscribe from their device lists. + self.store.handle_potentially_left_users_txn( + txn, {m for m in members_changed if not self.hs.is_mine_id(m)} + ) def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None: """Update the room version in the database based off current state @@ -1232,23 +1244,19 @@ class PersistEventsStore: def _update_forward_extremities_txn( self, txn: LoggingTransaction, - new_forward_extremities: Dict[str, Set[str]], + room_id: str, + new_forward_extremities: Set[str], max_stream_order: int, ) -> None: - for room_id in new_forward_extremities.keys(): - self.db_pool.simple_delete_txn( - txn, table="event_forward_extremities", keyvalues={"room_id": room_id} - ) + self.db_pool.simple_delete_txn( + txn, table="event_forward_extremities", keyvalues={"room_id": room_id} + ) self.db_pool.simple_insert_many_txn( txn, table="event_forward_extremities", keys=("event_id", "room_id"), - values=[ - (ev_id, room_id) - for room_id, new_extrem in new_forward_extremities.items() - for ev_id in new_extrem - ], + values=[(ev_id, room_id) for ev_id in new_forward_extremities], ) # We now insert into stream_ordering_to_exterm a mapping from room_id, # new stream_ordering to new forward extremeties in the room. 
@@ -1260,8 +1268,7 @@ class PersistEventsStore: keys=("room_id", "event_id", "stream_ordering"), values=[ (room_id, event_id, max_stream_order) - for room_id, new_extrem in new_forward_extremities.items() - for event_id in new_extrem + for event_id in new_forward_extremities ], ) @@ -1298,36 +1305,45 @@ class PersistEventsStore: def _update_room_depths_txn( self, txn: LoggingTransaction, + room_id: str, events_and_contexts: List[Tuple[EventBase, EventContext]], ) -> None: """Update min_depth for each room Args: txn: db connection + room_id: The room ID events_and_contexts: events we are persisting """ - depth_updates: Dict[str, int] = {} + stream_ordering: Optional[int] = None + depth_update = 0 for event, context in events_and_contexts: - # Then update the `stream_ordering` position to mark the latest - # event as the front of the room. This should not be done for - # backfilled events because backfilled events have negative - # stream_ordering and happened in the past so we know that we don't - # need to update the stream_ordering tip/front for the room. + # Don't update the stream ordering for backfilled events because + # backfilled events have negative stream_ordering and happened in the + # past, so we know that we don't need to update the stream_ordering + # tip/front for the room. assert event.internal_metadata.stream_ordering is not None if event.internal_metadata.stream_ordering >= 0: - txn.call_after( - self.store._events_stream_cache.entity_has_changed, - event.room_id, - event.internal_metadata.stream_ordering, - ) + if stream_ordering is None: + stream_ordering = event.internal_metadata.stream_ordering + else: + stream_ordering = max( + stream_ordering, event.internal_metadata.stream_ordering + ) if not event.internal_metadata.is_outlier() and not context.rejected: - depth_updates[event.room_id] = max( - event.depth, depth_updates.get(event.room_id, event.depth) - ) + depth_update = max(event.depth, depth_update) - for room_id, depth in depth_updates.items(): - self._update_min_depth_for_room_txn(txn, room_id, depth) + # Then update the `stream_ordering` position to mark the latest event as + # the front of the room. + if stream_ordering is not None: + txn.call_after( + self.store._events_stream_cache.entity_has_changed, + room_id, + stream_ordering, + ) + + self._update_min_depth_for_room_txn(txn, room_id, depth_update) def _update_outliers_txn( self, @@ -1350,13 +1366,19 @@ class PersistEventsStore: PartialStateConflictError: if attempting to persist a partial state event in a room that has been un-partial stated. 
""" - txn.execute( - "SELECT event_id, outlier FROM events WHERE event_id in (%s)" - % (",".join(["?"] * len(events_and_contexts)),), - [event.event_id for event, _ in events_and_contexts], + rows = cast( + List[Tuple[str, bool]], + self.db_pool.simple_select_many_txn( + txn, + "events", + "event_id", + [event.event_id for event, _ in events_and_contexts], + keyvalues={}, + retcols=("event_id", "outlier"), + ), ) - have_persisted = dict(cast(Iterable[Tuple[str, bool]], txn)) + have_persisted = dict(rows) logger.debug( "_update_outliers_txn: events=%s have_persisted=%s", @@ -1454,7 +1476,7 @@ class PersistEventsStore: txn, table="event_json", keys=("event_id", "room_id", "internal_metadata", "json", "format_version"), - values=( + values=[ ( event.event_id, event.room_id, @@ -1463,7 +1485,7 @@ class PersistEventsStore: event.format_version, ) for event, _ in events_and_contexts - ), + ], ) self.db_pool.simple_insert_many_txn( @@ -1486,7 +1508,7 @@ class PersistEventsStore: "state_key", "rejection_reason", ), - values=( + values=[ ( self._instance_name, event.internal_metadata.stream_ordering, @@ -1505,7 +1527,7 @@ class PersistEventsStore: context.rejected, ) for event, context in events_and_contexts - ), + ], ) # If we're persisting an unredacted event we go and ensure @@ -1528,11 +1550,11 @@ class PersistEventsStore: txn, table="state_events", keys=("event_id", "room_id", "type", "state_key"), - values=( + values=[ (event.event_id, event.room_id, event.type, event.state_key) for event, _ in events_and_contexts if event.is_state() - ), + ], ) def _store_rejected_events_txn( @@ -1912,8 +1934,7 @@ class PersistEventsStore: if row is None: return - redacted_relates_to = row["relates_to_id"] - rel_type = row["relation_type"] + redacted_relates_to, rel_type = row self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 0061805150..0c91f19c8e 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -425,7 +425,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): """Background update to clean out extremities that should have been deleted previously. - Mainly used to deal with the aftermath of #5269. + Mainly used to deal with the aftermath of https://github.com/matrix-org/synapse/issues/5269. """ # This works by first copying all existing forward extremities into the @@ -558,7 +558,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) logger.info( - "Deleted %d forward extremities of %d checked, to clean up #5269", + "Deleted %d forward extremities of %d checked, to clean up matrix-org/synapse#5269", deleted, len(original_set), ) @@ -1222,14 +1222,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) # Iterate the parent IDs and invalidate caches. - for parent_id in {r[1] for r in relations_to_insert}: - cache_tuple = (parent_id,) - self._invalidate_cache_and_stream( # type: ignore[attr-defined] - txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined] - ) - self._invalidate_cache_and_stream( # type: ignore[attr-defined] - txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined] - ) + cache_tuples = {(r[1],) for r in relations_to_insert} + self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined] + txn, self.get_relations_for_event, cache_tuples # type: ignore[attr-defined] + ) + self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined] + txn, self.get_thread_summary, cache_tuples # type: ignore[attr-defined] + ) if results: latest_event_id = results[-1][0] diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 5bf864c1fb..4125059061 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -1312,7 +1312,8 @@ class EventsWorkerStore(SQLBaseStore): room_version: Optional[RoomVersion] if not room_version_id: # this should only happen for out-of-band membership events which - # arrived before #6983 landed. For all other events, we should have + # arrived before https://github.com/matrix-org/synapse/issues/6983 + # landed. For all other events, we should have # an entry in the 'rooms' table. # # However, the 'out_of_band_membership' flag is unreliable for older @@ -1323,7 +1324,8 @@ class EventsWorkerStore(SQLBaseStore): "Room %s for event %s is unknown" % (d["room_id"], event_id) ) - # so, assuming this is an out-of-band-invite that arrived before #6983 + # so, assuming this is an out-of-band-invite that arrived before + # https://github.com/matrix-org/synapse/issues/6983 # landed, we know that the room version must be v5 or earlier (because # v6 hadn't been invented at that point, so invites from such rooms # would have been rejected.) @@ -1998,7 +2000,7 @@ class EventsWorkerStore(SQLBaseStore): if not res: raise SynapseError(404, "Could not find event %s" % (event_id,)) - return int(res["topological_ordering"]), int(res["stream_ordering"]) + return int(res[0]), int(res[1]) async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]: """Retrieve the entry with the lowest expiry timestamp in the event_expiry diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index ce88772f9e..c700872fdc 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py
@@ -107,13 +107,16 @@ class KeyStore(CacheInvalidationWorkerStore): # invalidate takes a tuple corresponding to the params of # _get_server_keys_json. _get_server_keys_json only takes one # param, which is itself the 2-tuple (server_name, key_id). - for key_id in verify_keys: - self._invalidate_cache_and_stream( - txn, self._get_server_keys_json, ((server_name, key_id),) - ) - self._invalidate_cache_and_stream( - txn, self.get_server_key_json_for_remote, (server_name, key_id) - ) + self._invalidate_cache_and_stream_bulk( + txn, + self._get_server_keys_json, + [((server_name, key_id),) for key_id in verify_keys], + ) + self._invalidate_cache_and_stream_bulk( + txn, + self.get_server_key_json_for_remote, + [(server_name, key_id) for key_id in verify_keys], + ) await self.db_pool.runInteraction( "store_server_keys_response", store_server_keys_response_txn diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index aeb3db596c..149135b8b5 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py
@@ -15,9 +15,7 @@ from enum import Enum from typing import ( TYPE_CHECKING, - Any, Collection, - Dict, Iterable, List, Optional, @@ -26,6 +24,8 @@ from typing import ( cast, ) +import attr + from synapse.api.constants import Direction from synapse.logging.opentracing import trace from synapse.media._base import ThumbnailInfo @@ -45,6 +45,40 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = ( ) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class LocalMedia: + media_id: str + media_type: str + media_length: Optional[int] + upload_name: str + created_ts: int + url_cache: Optional[str] + last_access_ts: int + quarantined_by: Optional[str] + safe_from_quarantine: bool + user_id: Optional[str] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RemoteMedia: + media_origin: str + media_id: str + media_type: str + media_length: int + upload_name: Optional[str] + filesystem_id: str + created_ts: int + last_access_ts: int + quarantined_by: Optional[str] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class UrlCache: + response_code: int + expires_ts: int + og: Union[str, bytes] + + class MediaSortOrder(Enum): """ Enum to define the sorting method used when returning media with @@ -116,6 +150,13 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): self._drop_media_index_without_method, ) + if hs.config.media.can_load_media_repo: + self.unused_expiration_time: Optional[ + int + ] = hs.config.media.unused_expiration_time + else: + self.unused_expiration_time = None + async def _drop_media_index_without_method( self, progress: JsonDict, batch_size: int ) -> int: @@ -151,13 +192,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): super().__init__(database, db_conn, hs) self.server_name: str = hs.hostname - async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: + async def get_local_media(self, media_id: str) -> Optional[LocalMedia]: """Get the metadata for a local piece of media Returns: None if the media_id doesn't exist. 
""" - return await self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -167,11 +208,27 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "created_ts", "quarantined_by", "url_cache", + "last_access_ts", "safe_from_quarantine", + "user_id", ), allow_none=True, desc="get_local_media", ) + if row is None: + return None + return LocalMedia( + media_id=media_id, + media_type=row[0], + media_length=row[1], + upload_name=row[2], + created_ts=row[3], + quarantined_by=row[4], + url_cache=row[5], + last_access_ts=row[6], + safe_from_quarantine=row[7], + user_id=row[8], + ) async def get_local_media_by_user_paginate( self, @@ -180,7 +237,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): user_id: str, order_by: str = MediaSortOrder.CREATED_TS.value, direction: Direction = Direction.FORWARDS, - ) -> Tuple[List[Dict[str, Any]], int]: + ) -> Tuple[List[LocalMedia], int]: """Get a paginated list of metadata for a local piece of media which an user_id has uploaded @@ -197,7 +254,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def get_local_media_by_user_paginate_txn( txn: LoggingTransaction, - ) -> Tuple[List[Dict[str, Any]], int]: + ) -> Tuple[List[LocalMedia], int]: # Set ordering order_by_column = MediaSortOrder(order_by).value @@ -217,14 +274,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): sql = """ SELECT - "media_id", - "media_type", - "media_length", - "upload_name", - "created_ts", - "last_access_ts", - "quarantined_by", - "safe_from_quarantine" + media_id, + media_type, + media_length, + upload_name, + created_ts, + url_cache, + last_access_ts, + quarantined_by, + safe_from_quarantine, + user_id FROM local_media_repository WHERE user_id = ? 
ORDER BY {order_by_column} {order}, media_id ASC @@ -236,7 +295,21 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): args += [limit, start] txn.execute(sql, args) - media = self.db_pool.cursor_to_dict(txn) + media = [ + LocalMedia( + media_id=row[0], + media_type=row[1], + media_length=row[2], + upload_name=row[3], + created_ts=row[4], + url_cache=row[5], + last_access_ts=row[6], + quarantined_by=row[7], + safe_from_quarantine=bool(row[8]), + user_id=row[9], + ) + for row in txn + ] return media, count return await self.db_pool.runInteraction( @@ -331,6 +404,23 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) @trace + async def store_local_media_id( + self, + media_id: str, + time_now_ms: int, + user_id: UserID, + ) -> None: + await self.db_pool.simple_insert( + "local_media_repository", + { + "media_id": media_id, + "created_ts": time_now_ms, + "user_id": user_id.to_string(), + }, + desc="store_local_media_id", + ) + + @trace async def store_local_media( self, media_id: str, @@ -355,6 +445,30 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_local_media", ) + async def update_local_media( + self, + media_id: str, + media_type: str, + upload_name: Optional[str], + media_length: int, + user_id: UserID, + url_cache: Optional[str] = None, + ) -> None: + await self.db_pool.simple_update_one( + "local_media_repository", + keyvalues={ + "user_id": user_id.to_string(), + "media_id": media_id, + }, + updatevalues={ + "media_type": media_type, + "upload_name": upload_name, + "media_length": media_length, + "url_cache": url_cache, + }, + desc="update_local_media", + ) + async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None: """Mark a local media as safe or unsafe from quarantining.""" await self.db_pool.simple_update_one( @@ -364,51 +478,72 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="mark_local_media_as_safe", ) - async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]: + async def count_pending_media(self, user_id: UserID) -> Tuple[int, int]: + """Count the number of pending media for a user. + + Returns: + A tuple of two integers: the total pending media requests and the earliest + expiration timestamp. + """ + + def get_pending_media_txn(txn: LoggingTransaction) -> Tuple[int, int]: + sql = """ + SELECT COUNT(*), MIN(created_ts) + FROM local_media_repository + WHERE user_id = ? + AND created_ts > ? + AND media_length IS NULL + """ + assert self.unused_expiration_time is not None + txn.execute( + sql, + ( + user_id.to_string(), + self._clock.time_msec() - self.unused_expiration_time, + ), + ) + row = txn.fetchone() + if not row: + return 0, 0 + return row[0], (row[1] + self.unused_expiration_time if row[1] else 0) + + return await self.db_pool.runInteraction( + "get_pending_media", get_pending_media_txn + ) + + async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]: """Get the media_id and ts for a cached URL as of the given timestamp Returns: None if the URL isn't cached. """ - def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: + def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]: # get the most recently cached result (relative to the given ts) - sql = ( - "SELECT response_code, etag, expires_ts, og, media_id, download_ts" - " FROM local_media_repository_url_cache" - " WHERE url = ? AND download_ts <= ?" 
- " ORDER BY download_ts DESC LIMIT 1" - ) + sql = """ + SELECT response_code, expires_ts, og + FROM local_media_repository_url_cache + WHERE url = ? AND download_ts <= ? + ORDER BY download_ts DESC LIMIT 1 + """ txn.execute(sql, (url, ts)) row = txn.fetchone() if not row: # ...or if we've requested a timestamp older than the oldest # copy in the cache, return the oldest copy (if any) - sql = ( - "SELECT response_code, etag, expires_ts, og, media_id, download_ts" - " FROM local_media_repository_url_cache" - " WHERE url = ? AND download_ts > ?" - " ORDER BY download_ts ASC LIMIT 1" - ) + sql = """ + SELECT response_code, expires_ts, og + FROM local_media_repository_url_cache + WHERE url = ? AND download_ts > ? + ORDER BY download_ts ASC LIMIT 1 + """ txn.execute(sql, (url, ts)) row = txn.fetchone() if not row: return None - return dict( - zip( - ( - "response_code", - "etag", - "expires_ts", - "og", - "media_id", - "download_ts", - ), - row, - ) - ) + return UrlCache(response_code=row[0], expires_ts=row[1], og=row[2]) return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) @@ -418,7 +553,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): response_code: int, etag: Optional[str], expires_ts: int, - og: Optional[str], + og: str, media_id: str, download_ts: int, ) -> None: @@ -484,8 +619,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def get_cached_remote_media( self, origin: str, media_id: str - ) -> Optional[Dict[str, Any]]: - return await self.db_pool.simple_select_one( + ) -> Optional[RemoteMedia]: + row = await self.db_pool.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( @@ -494,11 +629,25 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "upload_name", "created_ts", "filesystem_id", + "last_access_ts", "quarantined_by", ), allow_none=True, desc="get_cached_remote_media", ) + if row is None: + return row + return RemoteMedia( + media_origin=origin, + media_id=media_id, + media_type=row[0], + media_length=row[1], + upload_name=row[2], + created_ts=row[3], + filesystem_id=row[4], + last_access_ts=row[5], + quarantined_by=row[6], + ) async def store_cached_remote_media( self, @@ -597,10 +746,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): t_width: int, t_height: int, t_type: str, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[ThumbnailInfo]: """Fetch the thumbnail info of given width, height and type.""" - return await self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( table="remote_media_cache_thumbnails", keyvalues={ "media_origin": origin, @@ -615,11 +764,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "thumbnail_method", "thumbnail_type", "thumbnail_length", - "filesystem_id", ), allow_none=True, desc="get_remote_media_thumbnail", ) + if row is None: + return None + return ThumbnailInfo( + width=row[0], height=row[1], method=row[2], type=row[3], length=row[4] + ) @trace async def store_remote_media_thumbnail( diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 4b1061e6d7..2911e53310 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -317,7 +317,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): if user_id: is_support = self.is_support_user_txn(txn, user_id) if not is_support: - # We do this manually here to avoid hitting #6791 + # We do this manually here to avoid hitting https://github.com/matrix-org/synapse/issues/6791 self.db_pool.simple_upsert_txn( txn, table="monthly_active_users", diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 3b444d2d07..0198bb09d2 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py
@@ -363,10 +363,11 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) # for their user ID. value_values=[(presence_stream_id,) for _ in user_ids], ) - for user_id in user_ids: - self._invalidate_cache_and_stream( - txn, self._get_full_presence_stream_token_for_user, (user_id,) - ) + self._invalidate_cache_and_stream_bulk( + txn, + self._get_full_presence_stream_token_for_user, + [(user_id,) for user_id in user_ids], + ) return await self.db_pool.runInteraction( "add_users_to_send_full_presence_to", _add_users_to_send_full_presence_to diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 3ba9cc8853..7ed111f632 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py
@@ -13,7 +13,6 @@ # limitations under the License. from typing import TYPE_CHECKING, Optional -from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -138,23 +137,18 @@ class ProfileWorkerStore(SQLBaseStore): return 50 async def get_profileinfo(self, user_id: UserID) -> ProfileInfo: - try: - profile = await self.db_pool.simple_select_one( - table="profiles", - keyvalues={"full_user_id": user_id.to_string()}, - retcols=("displayname", "avatar_url"), - desc="get_profileinfo", - ) - except StoreError as e: - if e.code == 404: - # no match - return ProfileInfo(None, None) - else: - raise - - return ProfileInfo( - avatar_url=profile["avatar_url"], display_name=profile["displayname"] + profile = await self.db_pool.simple_select_one( + table="profiles", + keyvalues={"full_user_id": user_id.to_string()}, + retcols=("displayname", "avatar_url"), + desc="get_profileinfo", + allow_none=True, ) + if profile is None: + # no match + return ProfileInfo(None, None) + + return ProfileInfo(avatar_url=profile[1], display_name=profile[0]) async def get_profile_displayname(self, user_id: UserID) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 1e11bf2706..1a5b5731bb 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py
@@ -295,19 +295,28 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # so make sure to keep this actually last. txn.execute("DROP TABLE events_to_purge") - for event_id, should_delete in event_rows: - self._invalidate_cache_and_stream( - txn, self._get_state_group_for_event, (event_id,) - ) + self._invalidate_cache_and_stream_bulk( + txn, + self._get_state_group_for_event, + [(event_id,) for event_id, _ in event_rows], + ) - # XXX: This is racy, since have_seen_events could be called between the - # transaction completing and the invalidation running. On the other hand, - # that's no different to calling `have_seen_events` just before the - # event is deleted from the database. + # XXX: This is racy, since have_seen_events could be called between the + # transaction completing and the invalidation running. On the other hand, + # that's no different to calling `have_seen_events` just before the + # event is deleted from the database. + self._invalidate_cache_and_stream_bulk( + txn, + self.have_seen_event, + [ + (room_id, event_id) + for event_id, should_delete in event_rows + if should_delete + ], + ) + + for event_id, should_delete in event_rows: if should_delete: - self._invalidate_cache_and_stream( - txn, self.have_seen_event, (room_id, event_id) - ) self.invalidate_get_event_cache_after_txn(txn, event_id) logger.info("[purge] done") @@ -485,7 +494,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # - room_tags_revisions # The problem with these is that they are largeish and there is no room_id # index on them. In any case we should be clearing out 'stream' tables - # periodically anyway (#5888) + # periodically anyway (https://github.com/matrix-org/synapse/issues/5888) self._invalidate_caches_for_room_and_stream(txn, room_id) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 22025eca56..cf622e195c 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -28,8 +28,11 @@ from typing import ( cast, ) +from twisted.internet import defer + from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.replication.tcp.streams import PushRulesStream from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -51,7 +54,8 @@ from synapse.storage.util.id_generators import ( ) from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules from synapse.types import JsonDict -from synapse.util import json_encoder +from synapse.util import json_encoder, unwrapFirstError +from synapse.util.async_helpers import gather_results from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -249,23 +253,33 @@ class PushRulesWorkerStore( user_id: [] for user_id in user_ids } - rows = cast( - List[Tuple[str, str, int, int, str, str]], - await self.db_pool.simple_select_many_batch( - table="push_rules", - column="user_name", - iterable=user_ids, - retcols=( - "user_name", - "rule_id", - "priority_class", - "priority", - "conditions", - "actions", + # gatherResults loses all type information. + rows, enabled_map_by_user = await make_deferred_yieldable( + gather_results( + ( + cast( + "defer.Deferred[List[Tuple[str, str, int, int, str, str]]]", + run_in_background( + self.db_pool.simple_select_many_batch, + table="push_rules", + column="user_name", + iterable=user_ids, + retcols=( + "user_name", + "rule_id", + "priority_class", + "priority", + "conditions", + "actions", + ), + desc="bulk_get_push_rules", + batch_size=1000, + ), + ), + run_in_background(self.bulk_get_push_rules_enabled, user_ids), ), - desc="bulk_get_push_rules", - batch_size=1000, - ), + consumeErrors=True, + ).addErrback(unwrapFirstError) ) # Sort by highest priority_class, then highest priority. @@ -276,8 +290,6 @@ class PushRulesWorkerStore( (rule_id, priority_class, conditions, actions) ) - enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids) - results: Dict[str, FilteredPushRules] = {} for user_id, rules in raw_rules.items(): @@ -437,27 +449,28 @@ class PushRuleStore(PushRulesWorkerStore): before: str, after: str, ) -> None: - # Lock the table since otherwise we'll have annoying races between the - # SELECT here and the UPSERT below. - self.database_engine.lock_table(txn, "push_rules") - relative_to_rule = before or after - res = self.db_pool.simple_select_one_txn( - txn, - table="push_rules", - keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, - retcols=["priority_class", "priority"], - allow_none=True, - ) + sql = """ + SELECT priority, priority_class FROM push_rules + WHERE user_name = ? AND rule_id = ? 
+ """ - if not res: + if isinstance(self.database_engine, PostgresEngine): + sql += " FOR UPDATE" + else: + # Annoyingly SQLite doesn't support row level locking, so lock the whole table + self.database_engine.lock_table(txn, "push_rules") + + txn.execute(sql, (user_id, relative_to_rule)) + row = txn.fetchone() + + if row is None: raise RuleNotFoundException( "before/after rule not found: %s" % (relative_to_rule,) ) - base_priority_class = res["priority_class"] - base_rule_priority = res["priority"] + base_rule_priority, base_priority_class = row if base_priority_class != priority_class: raise InconsistentRuleException( @@ -505,9 +518,18 @@ class PushRuleStore(PushRulesWorkerStore): conditions_json: str, actions_json: str, ) -> None: - # Lock the table since otherwise we'll have annoying races between the - # SELECT here and the UPSERT below. - self.database_engine.lock_table(txn, "push_rules") + if isinstance(self.database_engine, PostgresEngine): + # Postgres doesn't do FOR UPDATE on aggregate functions, so select the rows first + # then re-select the count/max below. + sql = """ + SELECT * FROM push_rules + WHERE user_name = ? and priority_class = ? + FOR UPDATE + """ + txn.execute(sql, (user_id, priority_class)) + else: + # Annoyingly SQLite doesn't support row level locking, so lock the whole table + self.database_engine.lock_table(txn, "push_rules") # find the highest priority rule in that class sql = ( diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 56e8eb16a8..3484ce9ef9 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
@@ -701,8 +701,8 @@ class ReceiptsWorkerStore(SQLBaseStore): allow_none=True, ) - stream_ordering = int(res["stream_ordering"]) if res else None - rx_ts = res["received_ts"] if res else 0 + stream_ordering = int(res[0]) if res else None + rx_ts = res[1] if res else 0 # We don't want to clobber receipts for more recent events, so we # have to compare orderings of existing receipts diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e09ab21593..2c3f30e2eb 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): account timestamp as milliseconds since the epoch. None if the account has not been renewed using the current token yet. """ - ret_dict = await self.db_pool.simple_select_one( - table="account_validity", - keyvalues={"renewal_token": renewal_token}, - retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], - desc="get_user_from_renewal_token", - ) - - return ( - ret_dict["user_id"], - ret_dict["expiration_ts_ms"], - ret_dict["token_used_ts_ms"], + return cast( + Tuple[str, int, Optional[int]], + await self.db_pool.simple_select_one( + table="account_validity", + keyvalues={"renewal_token": renewal_token}, + retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], + desc="get_user_from_renewal_token", + ), ) async def get_renewal_token_for_user(self, user_id: str) -> str: @@ -564,16 +561,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): updatevalues={"shadow_banned": shadow_banned}, ) # In order for this to apply immediately, clear the cache for this user. - tokens = self.db_pool.simple_select_onecol_txn( + tokens = self.db_pool.simple_select_list_txn( txn, table="access_tokens", keyvalues={"user_id": user_id}, - retcol="token", + retcols=("token",), + ) + self._invalidate_cache_and_stream_bulk( + txn, self.get_user_by_access_token, tokens ) - for token in tokens: - self._invalidate_cache_and_stream( - txn, self.get_user_by_access_token, (token,) - ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn) @@ -989,16 +985,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Returns: user id, or None if no user id/threepid mapping exists """ - ret = self.db_pool.simple_select_one_txn( + return self.db_pool.simple_select_one_onecol_txn( txn, "user_threepids", {"medium": medium, "address": address}, - ["user_id"], + "user_id", True, ) - if ret: - return ret["user_id"] - return None async def user_add_threepid( self, @@ -1435,16 +1428,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): if res is None: return False + uses_allowed, pending, completed, expiry_time = res + # Check if the token has expired now = self._clock.time_msec() - if res["expiry_time"] and res["expiry_time"] < now: + if expiry_time and expiry_time < now: return False # Check if the token has been used up - if ( - res["uses_allowed"] - and res["pending"] + res["completed"] >= res["uses_allowed"] - ): + if uses_allowed and pending + completed >= uses_allowed: return False # Otherwise, the token is valid @@ -1490,8 +1482,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # Override type because the return type is only optional if # allow_none is True, and we don't want mypy throwing errors # about None not being indexable. 
- res = cast( - Dict[str, Any], + pending, completed = cast( + Tuple[int, int], self.db_pool.simple_select_one_txn( txn, "registration_tokens", @@ -1506,8 +1498,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "registration_tokens", keyvalues={"token": token}, updatevalues={ - "completed": res["completed"] + 1, - "pending": res["pending"] - 1, + "completed": completed + 1, + "pending": pending - 1, }, ) @@ -1517,7 +1509,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): async def get_registration_tokens( self, valid: Optional[bool] = None - ) -> List[Dict[str, Any]]: + ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]: """List all registration tokens. Used by the admin API. Args: @@ -1526,34 +1518,48 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Default is None: return all tokens regardless of validity. Returns: - A list of dicts, each containing details of a token. + A list of tuples containing: + * The token + * The number of users allowed (or None) + * Whether it is pending + * Whether it has been completed + * An expiry time (or None if no expiry) """ def select_registration_tokens_txn( txn: LoggingTransaction, now: int, valid: Optional[bool] - ) -> List[Dict[str, Any]]: + ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]: if valid is None: # Return all tokens regardless of validity - txn.execute("SELECT * FROM registration_tokens") + txn.execute( + """ + SELECT token, uses_allowed, pending, completed, expiry_time + FROM registration_tokens + """ + ) elif valid: # Select valid tokens only - sql = ( - "SELECT * FROM registration_tokens WHERE " - "(uses_allowed > pending + completed OR uses_allowed IS NULL) " - "AND (expiry_time > ? OR expiry_time IS NULL)" - ) + sql = """ + SELECT token, uses_allowed, pending, completed, expiry_time + FROM registration_tokens + WHERE (uses_allowed > pending + completed OR uses_allowed IS NULL) + AND (expiry_time > ? OR expiry_time IS NULL) + """ txn.execute(sql, [now]) else: # Select invalid tokens only - sql = ( - "SELECT * FROM registration_tokens WHERE " - "uses_allowed <= pending + completed OR expiry_time <= ?" - ) + sql = """ + SELECT token, uses_allowed, pending, completed, expiry_time + FROM registration_tokens + WHERE uses_allowed <= pending + completed OR expiry_time <= ? + """ txn.execute(sql, [now]) - return self.db_pool.cursor_to_dict(txn) + return cast( + List[Tuple[str, Optional[int], int, int, Optional[int]]], txn.fetchall() + ) return await self.db_pool.runInteraction( "select_registration_tokens", @@ -1571,13 +1577,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Returns: A dict, or None if token doesn't exist. 
""" - return await self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( "registration_tokens", keyvalues={"token": token}, retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"], allow_none=True, desc="get_one_registration_token", ) + if row is None: + return None + return { + "token": row[0], + "uses_allowed": row[1], + "pending": row[2], + "completed": row[3], + "expiry_time": row[4], + } async def generate_registration_token( self, length: int, chars: str @@ -1700,7 +1715,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return None # Get all info about the token so it can be sent in the response - return self.db_pool.simple_select_one_txn( + result = self.db_pool.simple_select_one_txn( txn, "registration_tokens", keyvalues={"token": token}, @@ -1714,6 +1729,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): allow_none=True, ) + if result is None: + return result + + return { + "token": result[0], + "uses_allowed": result[1], + "pending": result[2], + "completed": result[3], + "expiry_time": result[4], + } + return await self.db_pool.runInteraction( "update_registration_token", _update_registration_token_txn ) @@ -1925,11 +1951,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): keyvalues={"token": token}, updatevalues={"used_ts": ts}, ) - user_id = values["user_id"] - expiry_ts = values["expiry_ts"] - used_ts = values["used_ts"] - auth_provider_id = values["auth_provider_id"] - auth_provider_session_id = values["auth_provider_session_id"] + ( + user_id, + expiry_ts, + used_ts, + auth_provider_id, + auth_provider_session_id, + ) = values # Token was already used if used_ts is not None: @@ -2654,10 +2682,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): ) tokens_and_devices = [(r[0], r[1], r[2]) for r in txn] - for token, _, _ in tokens_and_devices: - self._invalidate_cache_and_stream( - txn, self.get_user_by_access_token, (token,) - ) + self._invalidate_cache_and_stream_bulk( + txn, + self.get_user_by_access_token, + [(token,) for token, _, _ in tokens_and_devices], + ) txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values) @@ -2742,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): # reason, the next check is on the client secret, which is NOT NULL, # so we don't have to worry about the client secret matching by # accident. - row = {"client_secret": None, "validated_at": None} + row = None, None else: raise ThreepidValidationError("Unknown session_id") - retrieved_client_secret = row["client_secret"] - validated_at = row["validated_at"] + retrieved_client_secret, validated_at = row row = self.db_pool.simple_select_one_txn( txn, @@ -2761,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): raise ThreepidValidationError( "Validation token not found or has expired" ) - expires = row["expires"] - next_link = row["next_link"] + expires, next_link = row if retrieved_client_secret != client_secret: raise ThreepidValidationError( diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 3e8fcf1975..ef26d5d9d3 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -78,6 +78,31 @@ class RatelimitOverride: burst_count: int +@attr.s(slots=True, frozen=True, auto_attribs=True) +class LargestRoomStats: + room_id: str + name: Optional[str] + canonical_alias: Optional[str] + joined_members: int + join_rules: Optional[str] + guest_access: Optional[str] + history_visibility: Optional[str] + state_events: int + avatar: Optional[str] + topic: Optional[str] + room_type: Optional[str] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomStats(LargestRoomStats): + joined_local_members: int + version: Optional[str] + creator: Optional[str] + encryption: Optional[str] + federatable: bool + public: bool + + class RoomSortOrder(Enum): """ Enum to define the sorting method used when returning rooms with get_rooms_paginate @@ -188,23 +213,33 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") - async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]: + async def get_room(self, room_id: str) -> Optional[Tuple[bool, bool]]: """Retrieve a room. Args: room_id: The ID of the room to retrieve. Returns: - A dict containing the room information, or None if the room is unknown. + A tuple containing the room information: + * True if the room is public + * True if the room has an auth chain index + + or None if the room is unknown. """ - return await self.db_pool.simple_select_one( - table="rooms", - keyvalues={"room_id": room_id}, - retcols=("room_id", "is_public", "creator", "has_auth_chain_index"), - desc="get_room", - allow_none=True, + row = cast( + Optional[Tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]], + await self.db_pool.simple_select_one( + table="rooms", + keyvalues={"room_id": room_id}, + retcols=("is_public", "has_auth_chain_index"), + desc="get_room", + allow_none=True, + ), ) + if row is None: + return row + return bool(row[0]), bool(row[1]) - async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]: + async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]: """Retrieve room with statistics. Args: @@ -215,7 +250,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def get_room_with_stats_txn( txn: LoggingTransaction, room_id: str - ) -> Optional[Dict[str, Any]]: + ) -> Optional[RoomStats]: sql = """ SELECT room_id, state.name, state.canonical_alias, curr.joined_members, curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, @@ -229,15 +264,28 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): WHERE room_id = ? 
""" txn.execute(sql, [room_id]) - # Catch error if sql returns empty result to return "None" instead of an error - try: - res = self.db_pool.cursor_to_dict(txn)[0] - except IndexError: + row = txn.fetchone() + if not row: return None - - res["federatable"] = bool(res["federatable"]) - res["public"] = bool(res["public"]) - return res + return RoomStats( + room_id=row[0], + name=row[1], + canonical_alias=row[2], + joined_members=row[3], + joined_local_members=row[4], + version=row[5], + creator=row[6], + encryption=row[7], + federatable=bool(row[8]), + public=bool(row[9]), + join_rules=row[10], + guest_access=row[11], + history_visibility=row[12], + state_events=row[13], + avatar=row[14], + topic=row[15], + room_type=row[16], + ) return await self.db_pool.runInteraction( "get_room_with_stats", get_room_with_stats_txn, room_id @@ -368,7 +416,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): bounds: Optional[Tuple[int, str]], forwards: bool, ignore_non_federatable: bool = False, - ) -> List[Dict[str, Any]]: + ) -> List[LargestRoomStats]: """Gets the largest public rooms (where largest is in terms of joined members, as tracked in the statistics table). @@ -505,20 +553,34 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def _get_largest_public_rooms_txn( txn: LoggingTransaction, - ) -> List[Dict[str, Any]]: + ) -> List[LargestRoomStats]: txn.execute(sql, query_args) - results = self.db_pool.cursor_to_dict(txn) + results = [ + LargestRoomStats( + room_id=r[0], + name=r[1], + canonical_alias=r[3], + joined_members=r[4], + join_rules=r[8], + guest_access=r[7], + history_visibility=r[6], + state_events=0, + avatar=r[5], + topic=r[2], + room_type=r[9], + ) + for r in txn + ] if not forwards: results.reverse() return results - ret_val = await self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_largest_public_rooms", _get_largest_public_rooms_txn ) - return ret_val @cached(max_entries=10000) async def is_room_blocked(self, room_id: str) -> Optional[bool]: @@ -742,10 +804,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): ) if row: - return RatelimitOverride( - messages_per_second=row["messages_per_second"], - burst_count=row["burst_count"], - ) + return RatelimitOverride(messages_per_second=row[0], burst_count=row[1]) else: return None @@ -1319,13 +1378,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): join. """ - result = await self.db_pool.simple_select_one( - table="partial_state_rooms", - keyvalues={"room_id": room_id}, - retcols=("join_event_id", "device_lists_stream_id"), - desc="get_join_event_id_for_partial_state", + return cast( + Tuple[str, int], + await self.db_pool.simple_select_one( + table="partial_state_rooms", + keyvalues={"room_id": room_id}, + retcols=("join_event_id", "device_lists_stream_id"), + desc="get_join_event_id_for_partial_state", + ), ) - return result["join_event_id"], result["device_lists_stream_id"] def get_un_partial_stated_rooms_token(self, instance_name: str) -> int: return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer( @@ -2216,7 +2277,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): txn, table="partial_state_rooms_servers", keys=("room_id", "server_name"), - values=((room_id, s) for s in servers), + values=[(room_id, s) for s in servers], ) self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) self._invalidate_cache_and_stream( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 1ed7f2d0ef..60d4a9ef30 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
@@ -559,17 +559,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): "non-local user %s" % (user_id,), ) - results_dict = await self.db_pool.simple_select_one( - "local_current_membership", - {"room_id": room_id, "user_id": user_id}, - ("membership", "event_id"), - allow_none=True, - desc="get_local_current_membership_for_user_in_room", + results = cast( + Optional[Tuple[str, str]], + await self.db_pool.simple_select_one( + "local_current_membership", + {"room_id": room_id, "user_id": user_id}, + ("membership", "event_id"), + allow_none=True, + desc="get_local_current_membership_for_user_in_room", + ), ) - if not results_dict: + if not results: return None, None - return results_dict.get("membership"), results_dict.get("event_id") + return results @cached(max_entries=500000, iterable=True) async def get_rooms_for_user_with_stream_ordering( diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
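The membership lookup above narrows `simple_select_one`'s loose return type with `typing.cast`, which only informs the type checker and does nothing at runtime. A tiny stand-alone illustration of the idiom, with the awaited database call replaced by an ordinary argument:

from typing import Any, Optional, Tuple, cast

def current_membership(row: Any) -> Tuple[Optional[str], Optional[str]]:
    results = cast(Optional[Tuple[str, str]], row)  # no runtime conversion
    if not results:
        return None, None
    return results  # mypy now treats this as Tuple[str, str]

print(current_membership(("join", "$event_id")))
print(current_membership(None))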
index dbde9130c6..e25d86818b 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py
@@ -106,7 +106,7 @@ class SearchWorkerStore(SQLBaseStore): txn, table="event_search", keys=("event_id", "room_id", "key", "value"), - values=( + values=[ ( entry.event_id, entry.room_id, @@ -114,7 +114,7 @@ class SearchWorkerStore(SQLBaseStore): _clean_value_for_search(entry.value), ) for entry in entries - ), + ], ) else: @@ -275,7 +275,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): # we have to set autocommit, because postgres refuses to # CREATE INDEX CONCURRENTLY without it. - conn.set_session(autocommit=True) + conn.engine.attempt_to_set_autocommit(conn.conn, True) try: c = conn.cursor() @@ -301,7 +301,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): # we should now be able to delete the GIST index. c.execute("DROP INDEX IF EXISTS event_search_fts_idx_gist") finally: - conn.set_session(autocommit=False) + conn.engine.attempt_to_set_autocommit(conn.conn, False) if isinstance(self.database_engine, PostgresEngine): await self.db_pool.runWithConnection(create_index) @@ -323,7 +323,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): def create_index(conn: LoggingDatabaseConnection) -> None: conn.rollback() - conn.set_session(autocommit=True) + conn.engine.attempt_to_set_autocommit(conn.conn, True) c = conn.cursor() # We create with NULLS FIRST so that when we search *backwards* @@ -340,7 +340,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): ON event_search(origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST) """ ) - conn.set_session(autocommit=False) + conn.engine.attempt_to_set_autocommit(conn.conn, False) await self.db_pool.runWithConnection(create_index) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
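The search background updates switch from `conn.set_session(autocommit=True)` to `engine.attempt_to_set_autocommit(...)`, because `CREATE INDEX CONCURRENTLY` refuses to run inside a transaction block and the engine-level wrapper keeps the call portable across database engines. A hedged psycopg2-only sketch of the same dance; the DSN and index definition are placeholders:

import psycopg2

conn = psycopg2.connect("dbname=synapse")  # placeholder DSN
try:
    conn.rollback()          # make sure no transaction is open
    conn.autocommit = True   # CONCURRENTLY cannot run inside a transaction
    with conn.cursor() as cur:
        cur.execute(
            "CREATE INDEX CONCURRENTLY IF NOT EXISTS event_search_demo_idx "
            "ON event_search (room_id)"
        )
finally:
    conn.autocommit = False
    conn.close()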
index 2225f8272d..563c275a2c 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py
@@ -1014,9 +1014,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): desc="get_position_for_event", ) - return PersistedEventPosition( - row["instance_name"] or "master", row["stream_ordering"] - ) + return PersistedEventPosition(row[1] or "master", row[0]) async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken: """The stream token for an event @@ -1033,9 +1031,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", ) - return RoomStreamToken( - topological=row["topological_ordering"], stream=row["stream_ordering"] - ) + return RoomStreamToken(topological=row[1], stream=row[0]) async def get_current_topological_token(self, room_id: str, stream_key: int) -> int: """Gets the topological token in a room after or at the given stream @@ -1180,26 +1176,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - results = self.db_pool.simple_select_one_txn( - txn, - "events", - keyvalues={"event_id": event_id, "room_id": room_id}, - retcols=["stream_ordering", "topological_ordering"], + stream_ordering, topological_ordering = cast( + Tuple[int, int], + self.db_pool.simple_select_one_txn( + txn, + "events", + keyvalues={"event_id": event_id, "room_id": room_id}, + retcols=["stream_ordering", "topological_ordering"], + ), ) - # This cannot happen as `allow_none=False`. - assert results is not None - # Paginating backwards includes the event at the token, but paginating # forward doesn't. before_token = RoomStreamToken( - topological=results["topological_ordering"] - 1, - stream=results["stream_ordering"], + topological=topological_ordering - 1, stream=stream_ordering ) after_token = RoomStreamToken( - topological=results["topological_ordering"], - stream=results["stream_ordering"], + topological=topological_ordering, stream=stream_ordering ) rows, start_token = self._paginate_room_events_txn( diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py
index 5555b53575..64543b4d61 100644 --- a/synapse/storage/databases/main/task_scheduler.py +++ b/synapse/storage/databases/main/task_scheduler.py
@@ -183,39 +183,27 @@ class TaskSchedulerWorkerStore(SQLBaseStore): Returns: the task if available, `None` otherwise """ - row = await self.db_pool.simple_select_one( - table="scheduled_tasks", - keyvalues={"id": id}, - retcols=( - "id", - "action", - "status", - "timestamp", - "resource_id", - "params", - "result", - "error", + row = cast( + Optional[ScheduledTaskRow], + await self.db_pool.simple_select_one( + table="scheduled_tasks", + keyvalues={"id": id}, + retcols=( + "id", + "action", + "status", + "timestamp", + "resource_id", + "params", + "result", + "error", + ), + allow_none=True, + desc="get_scheduled_task", ), - allow_none=True, - desc="get_scheduled_task", ) - return ( - TaskSchedulerWorkerStore._convert_row_to_task( - ( - row["id"], - row["action"], - row["status"], - row["timestamp"], - row["resource_id"], - row["params"], - row["result"], - row["error"], - ) - ) - if row - else None - ) + return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None async def delete_scheduled_task(self, id: str) -> None: """Delete a specific task from its id. diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index fecddb4144..2d341affaa 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py
@@ -118,19 +118,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): txn, table="received_transactions", keyvalues={"transaction_id": transaction_id, "origin": origin}, - retcols=( - "transaction_id", - "origin", - "ts", - "response_code", - "response_json", - "has_been_referenced", - ), + retcols=("response_code", "response_json"), allow_none=True, ) - if result and result["response_code"]: - return result["response_code"], db_to_json(result["response_json"]) + # If the result exists and the response code is non-0. + if result and result[0]: + return result[0], db_to_json(result[1]) else: return None @@ -200,8 +194,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): # check we have a row and retry_last_ts is not null or zero # (retry_last_ts can't be negative) - if result and result["retry_last_ts"]: - return DestinationRetryTimings(**result) + if result and result[1]: + return DestinationRetryTimings( + failure_ts=result[0], retry_last_ts=result[1], retry_interval=result[2] + ) else: return None diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 8ab7c42c4a..5b164fed8e 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py
@@ -122,9 +122,13 @@ class UIAuthWorkerStore(SQLBaseStore): desc="get_ui_auth_session", ) - result["clientdict"] = db_to_json(result["clientdict"]) - - return UIAuthSessionData(session_id, **result) + return UIAuthSessionData( + session_id, + clientdict=db_to_json(result[0]), + uri=result[1], + method=result[2], + description=result[3], + ) async def mark_ui_auth_stage_complete( self, @@ -231,18 +235,15 @@ class UIAuthWorkerStore(SQLBaseStore): self, txn: LoggingTransaction, session_id: str, key: str, value: Any ) -> None: # Get the current value. - result = cast( - Dict[str, Any], - self.db_pool.simple_select_one_txn( - txn, - table="ui_auth_sessions", - keyvalues={"session_id": session_id}, - retcols=("serverdict",), - ), + result = self.db_pool.simple_select_one_onecol_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcol="serverdict", ) # Update it and add it back to the database. - serverdict = db_to_json(result["serverdict"]) + serverdict = db_to_json(result) serverdict[key] = value self.db_pool.simple_update_one_txn( @@ -265,14 +266,14 @@ class UIAuthWorkerStore(SQLBaseStore): Raises: StoreError if the session cannot be found. """ - result = await self.db_pool.simple_select_one( + result = await self.db_pool.simple_select_one_onecol( table="ui_auth_sessions", keyvalues={"session_id": session_id}, - retcols=("serverdict",), + retcol="serverdict", desc="get_ui_auth_session_data", ) - serverdict = db_to_json(result["serverdict"]) + serverdict = db_to_json(result) return serverdict.get(key, default) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index a9f5d68b63..1a38f3d785 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py
@@ -20,7 +20,6 @@ from typing import ( Collection, Iterable, List, - Mapping, Optional, Sequence, Set, @@ -833,13 +832,25 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): "delete_all_from_user_dir", _delete_all_from_user_dir_txn ) - async def _get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]: - return await self.db_pool.simple_select_one( - table="user_directory", - keyvalues={"user_id": user_id}, - retcols=("display_name", "avatar_url"), - allow_none=True, - desc="get_user_in_directory", + async def _get_user_in_directory( + self, user_id: str + ) -> Optional[Tuple[Optional[str], Optional[str]]]: + """ + Fetch the user information in the user directory. + + Returns: + None if the user is unknown, otherwise a tuple of display name and + avatar URL (both of which may be None). + """ + return cast( + Optional[Tuple[Optional[str], Optional[str]]], + await self.db_pool.simple_select_one( + table="user_directory", + keyvalues={"user_id": user_id}, + retcols=("display_name", "avatar_url"), + allow_none=True, + desc="get_user_in_directory", + ), ) async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None: diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 0f9c550b27..2c3151526d 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py
@@ -492,7 +492,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): conn.rollback() if isinstance(self.database_engine, PostgresEngine): # postgres insists on autocommit for the index - conn.set_session(autocommit=True) + conn.engine.attempt_to_set_autocommit(conn.conn, True) try: txn = conn.cursor() txn.execute( @@ -501,7 +501,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): ) txn.execute("DROP INDEX IF EXISTS state_groups_state_id") finally: - conn.set_session(autocommit=False) + conn.engine.attempt_to_set_autocommit(conn.conn, False) else: txn = conn.cursor() txn.execute( diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 6309363217..ec4c4041b7 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py
@@ -38,7 +38,8 @@ class PostgresEngine( super().__init__(psycopg2, database_config) psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) - # Disables passing `bytes` to txn.execute, c.f. #6186. If you do + # Disables passing `bytes` to txn.execute, c.f. + # https://github.com/matrix-org/synapse/issues/6186. If you do # actually want to use bytes than wrap it in `bytearray`. def _disable_bytes_adapter(_: bytes) -> NoReturn: raise Exception("Passing bytes to DB is disabled.") diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 158b528dce..03e5a0f55d 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py
@@ -109,7 +109,8 @@ Changes in SCHEMA_VERSION = 78 Changes in SCHEMA_VERSION = 79 - Add tables to handle in DB read-write locks. - - Add some mitigations for a painful race between foreground and background updates, cf #15677. + - Add some mitigations for a painful race between foreground and background updates, cf + https://github.com/matrix-org/synapse/issues/15677. Changes in SCHEMA_VERSION = 80 - The event_txn_id_device_id is always written to for new events. diff --git a/synapse/storage/schema/main/delta/54/delete_forward_extremities.sql b/synapse/storage/schema/main/delta/54/delete_forward_extremities.sql
index b062ec840c..f713e42aa0 100644 --- a/synapse/storage/schema/main/delta/54/delete_forward_extremities.sql +++ b/synapse/storage/schema/main/delta/54/delete_forward_extremities.sql
@@ -14,7 +14,7 @@ */ -- Start a background job to cleanup extremities that were incorrectly added --- by bug #5269. +-- by bug https://github.com/matrix-org/synapse/issues/5269. INSERT INTO background_updates (update_name, progress_json) VALUES ('delete_soft_failed_extremities', '{}'); diff --git a/synapse/storage/schema/main/delta/56/remove_tombstoned_rooms_from_directory.sql b/synapse/storage/schema/main/delta/56/remove_tombstoned_rooms_from_directory.sql
index aeb17813d3..246c3359f7 100644 --- a/synapse/storage/schema/main/delta/56/remove_tombstoned_rooms_from_directory.sql +++ b/synapse/storage/schema/main/delta/56/remove_tombstoned_rooms_from_directory.sql
@@ -13,6 +13,7 @@ * limitations under the License. */ --- Now that #6232 is a thing, we can remove old rooms from the directory. +-- Now that https://github.com/matrix-org/synapse/pull/6232 is a thing, we can +-- remove old rooms from the directory. INSERT INTO background_updates (update_name, progress_json) VALUES ('remove_tombstoned_rooms_from_directory', '{}'); diff --git a/synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql b/synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql
index aed79635b2..31a61defa7 100644 --- a/synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql +++ b/synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql
@@ -13,7 +13,8 @@ * limitations under the License. */ --- Clean up left over rows from bug #11833, which was fixed in #12770. +-- Clean up left over rows from bug https://github.com/matrix-org/synapse/issues/11833, +-- which was fixed in https://github.com/matrix-org/synapse/pull/12770. DELETE FROM federation_inbound_events_staging WHERE room_id not in ( SELECT room_id FROM rooms ); diff --git a/synapse/storage/schema/main/delta/83/05_cross_signing_key_update_grant.sql b/synapse/storage/schema/main/delta/83/05_cross_signing_key_update_grant.sql new file mode 100644
index 0000000000..b74bdd71fa --- /dev/null +++ b/synapse/storage/schema/main/delta/83/05_cross_signing_key_update_grant.sql
@@ -0,0 +1,15 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +ALTER TABLE e2e_cross_signing_keys ADD COLUMN updatable_without_uia_before_ms bigint DEFAULT NULL; \ No newline at end of file diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9c3eafb562..bd3c81827f 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py
@@ -650,8 +650,8 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): next_id = self._load_next_id_txn(txn) - txn.call_after(self._mark_id_as_finished, next_id) - txn.call_on_exception(self._mark_id_as_finished, next_id) + txn.call_after(self._mark_ids_as_finished, [next_id]) + txn.call_on_exception(self._mark_ids_as_finished, [next_id]) txn.call_after(self._notifier.notify_replication) # Update the `stream_positions` table with newly updated stream @@ -671,14 +671,50 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): return self._return_factor * next_id - def _mark_id_as_finished(self, next_id: int) -> None: - """The ID has finished being processed so we should advance the + def get_next_mult_txn(self, txn: LoggingTransaction, n: int) -> List[int]: + """ + Usage: + + stream_id = stream_id_gen.get_next_txn(txn) + # ... persist event ... + """ + + # If we have a list of instances that are allowed to write to this + # stream, make sure we're in it. + if self._writers and self._instance_name not in self._writers: + raise Exception("Tried to allocate stream ID on non-writer") + + next_ids = self._load_next_mult_id_txn(txn, n) + + txn.call_after(self._mark_ids_as_finished, next_ids) + txn.call_on_exception(self._mark_ids_as_finished, next_ids) + txn.call_after(self._notifier.notify_replication) + + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persisted row with the correct instance name. + if self._writers: + txn.call_after( + run_as_background_process, + "MultiWriterIdGenerator._update_table", + self._db.runInteraction, + "MultiWriterIdGenerator._update_table", + self._update_stream_positions_table_txn, + ) + + return [self._return_factor * next_id for next_id in next_ids] + + def _mark_ids_as_finished(self, next_ids: List[int]) -> None: + """These IDs have finished being processed so we should advance the current position if possible. """ with self._lock: - self._unfinished_ids.discard(next_id) - self._finished_ids.add(next_id) + self._unfinished_ids.difference_update(next_ids) + self._finished_ids.update(next_ids) new_cur: Optional[int] = None @@ -727,7 +763,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): curr, new_cur, self._max_position_of_local_instance ) - self._add_persisted_position(next_id) + # TODO Can we call this for just the last position or somehow batch + # _add_persisted_position. + for next_id in next_ids: + self._add_persisted_position(next_id) def get_current_token(self) -> int: return self.get_persisted_upto_position() @@ -933,8 +972,7 @@ class _MultiWriterCtxManager: exc: Optional[BaseException], tb: Optional[TracebackType], ) -> bool: - for i in self.stream_ids: - self.id_gen._mark_id_as_finished(i) + self.id_gen._mark_ids_as_finished(self.stream_ids) self.notifier.notify_replication() diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
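`_mark_ids_as_finished` above now moves a whole batch of IDs from the unfinished set to the finished set under one lock before advancing the persisted position. A toy model of just that bookkeeping — this is not the real MultiWriterIdGenerator, which also tracks per-writer positions and replication:

import threading
from typing import List, Set

class ToyIdTracker:
    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._unfinished: Set[int] = set()
        self._finished: Set[int] = set()
        self.current_position = 0

    def allocate(self, ids: List[int]) -> None:
        with self._lock:
            self._unfinished.update(ids)

    def mark_ids_as_finished(self, ids: List[int]) -> None:
        with self._lock:
            self._unfinished.difference_update(ids)
            self._finished.update(ids)
            # Advance past finished IDs not blocked by a smaller unfinished ID.
            floor = min(self._unfinished, default=None)
            advanced = {i for i in self._finished if floor is None or i < floor}
            if advanced:
                self.current_position = max(advanced)
                self._finished -= advanced

t = ToyIdTracker()
t.allocate([1, 2, 3])
t.mark_ids_as_finished([2, 3])
print(t.current_position)  # 0 -- ID 1 is still in flight
t.mark_ids_as_finished([1])
print(t.current_position)  # 3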
index 9f3b8741c1..8d9df352b2 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py
@@ -93,7 +93,7 @@ class Clock: _reactor: IReactorTime = attr.ib() - @defer.inlineCallbacks # type: ignore[arg-type] # Issue in Twisted's type annotations + @defer.inlineCallbacks def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]": d: defer.Deferred[float] = defer.Deferred() with context.PreserveLoggingContext(): diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 0cbeb0c365..8a55e4e41d 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py
@@ -345,6 +345,7 @@ async def yieldable_gather_results_delaying_cancellation( T1 = TypeVar("T1") T2 = TypeVar("T2") T3 = TypeVar("T3") +T4 = TypeVar("T4") @overload @@ -380,6 +381,19 @@ def gather_results( ... +@overload +def gather_results( + deferredList: Tuple[ + "defer.Deferred[T1]", + "defer.Deferred[T2]", + "defer.Deferred[T3]", + "defer.Deferred[T4]", + ], + consumeErrors: bool = ..., +) -> "defer.Deferred[Tuple[T1, T2, T3, T4]]": + ... + + def gather_results( # type: ignore[misc] deferredList: Tuple["defer.Deferred[T1]", ...], consumeErrors: bool = False, diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py
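The new `@overload` gives `gather_results` a four-tuple signature so callers keep per-element types, which plain `gatherResults` loses (as the push-rules change above notes). A toy, synchronous illustration of the per-arity overload pattern — `pair_up` is invented for the example, and the `type: ignore[misc]` on the implementation mirrors the one already used in this file:

from typing import Tuple, TypeVar, overload

T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")

@overload
def pair_up(items: Tuple[T1, T2]) -> Tuple[T1, T2]: ...
@overload
def pair_up(items: Tuple[T1, T2, T3]) -> Tuple[T1, T2, T3]: ...

def pair_up(items):  # type: ignore[misc]
    return tuple(items)

x = pair_up((1, "a"))        # type checker sees Tuple[int, str]
y = pair_up((1, "a", b"c"))  # type checker sees Tuple[int, str, bytes]
print(x, y)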
index f7cead9e12..6f008734a0 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py
@@ -189,7 +189,8 @@ def check_requirements(extra: Optional[str] = None) -> None: errors.append(_not_installed(requirement, extra)) else: if dist.version is None: - # This shouldn't happen---it suggests a borked virtualenv. (See #12223) + # This shouldn't happen---it suggests a borked virtualenv. (See + # https://github.com/matrix-org/synapse/issues/12223) # Try to give a vaguely helpful error message anyway. # Type-ignore: the annotations don't reflect reality: see # https://github.com/python/typeshed/issues/7513 diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index a0efb96d3b..f4c0194af0 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py
@@ -135,3 +135,54 @@ def sorted_topologically( degree_map[edge] -= 1 if degree_map[edge] == 0: heapq.heappush(zero_degree, edge) + + +def sorted_topologically_batched( + nodes: Iterable[T], + graph: Mapping[T, Collection[T]], +) -> Generator[Collection[T], None, None]: + r"""Walk the graph topologically, returning batches of nodes where all nodes + that references it have been previously returned. + + For example, given the following graph: + + A + / \ + B C + \ / + D + + This function will return: `[[A], [B, C], [D]]`. + + This function is useful for e.g. batch persisting events in an auth chain, + where we can only persist an event if all its auth events have already been + persisted. + """ + + degree_map = {node: 0 for node in nodes} + reverse_graph: Dict[T, Set[T]] = {} + + for node, edges in graph.items(): + if node not in degree_map: + continue + + for edge in set(edges): + if edge in degree_map: + degree_map[node] += 1 + + reverse_graph.setdefault(edge, set()).add(node) + reverse_graph.setdefault(node, set()) + + zero_degree = [node for node, degree in degree_map.items() if degree == 0] + + while zero_degree: + new_zero_degree = [] + for node in zero_degree: + for edge in reverse_graph.get(node, []): + if edge in degree_map: + degree_map[edge] -= 1 + if degree_map[edge] == 0: + new_zero_degree.append(edge) + + yield zero_degree + zero_degree = new_zero_degree diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
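Assuming this patch is applied, `sorted_topologically_batched` can be exercised directly with the diamond graph from its docstring, where `graph` maps each node to the nodes it references:

from synapse.util.iterutils import sorted_topologically_batched

graph = {"A": [], "B": ["A"], "C": ["A"], "D": ["B", "C"]}
batches = list(sorted_topologically_batched(graph.keys(), graph))
print(batches)  # e.g. [['A'], ['B', 'C'], ['D']] -- order within a batch is not guaranteed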
index caf13b3474..8c2df233d3 100644 --- a/synapse/util/task_scheduler.py +++ b/synapse/util/task_scheduler.py
@@ -71,7 +71,7 @@ class TaskScheduler: # Time before a complete or failed task is deleted from the DB KEEP_TASKS_FOR_MS = 7 * 24 * 60 * 60 * 1000 # 1 week # Maximum number of tasks that can run at the same time - MAX_CONCURRENT_RUNNING_TASKS = 10 + MAX_CONCURRENT_RUNNING_TASKS = 5 # Time from the last task update after which we will log a warning LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000 # 24hrs @@ -193,7 +193,7 @@ class TaskScheduler: result: Optional[JsonMapping] = None, error: Optional[str] = None, ) -> bool: - """Update some task associated values. This is exposed publically so it can + """Update some task associated values. This is exposed publicly so it can be used inside task functions, mainly to update the result and be able to resume a task at a specific step after a restart of synapse.