diff options
56 files changed, 749 insertions, 460 deletions
diff --git a/changelog.d/16611.misc b/changelog.d/16611.misc new file mode 100644 index 0000000000..93ceaeafc9 --- /dev/null +++ b/changelog.d/16611.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/changelog.d/16612.misc b/changelog.d/16612.misc new file mode 100644 index 0000000000..93ceaeafc9 --- /dev/null +++ b/changelog.d/16612.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/changelog.d/16613.feature b/changelog.d/16613.feature new file mode 100644 index 0000000000..419c56fb83 --- /dev/null +++ b/changelog.d/16613.feature @@ -0,0 +1 @@ +Improve the performance of some operations in multi-worker deployments. diff --git a/changelog.d/16616.feature b/changelog.d/16616.feature new file mode 100644 index 0000000000..419c56fb83 --- /dev/null +++ b/changelog.d/16616.feature @@ -0,0 +1 @@ +Improve the performance of some operations in multi-worker deployments. diff --git a/changelog.d/16618.misc b/changelog.d/16618.misc new file mode 100644 index 0000000000..c026e6b995 --- /dev/null +++ b/changelog.d/16618.misc @@ -0,0 +1 @@ +Use `dbname` instead of the deprecated `database` connection parameter for psycopg2. diff --git a/docs/postgres.md b/docs/postgres.md index 02d4b9b162..ad7c6a0738 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -66,7 +66,7 @@ database: args: user: <user> password: <pass> - database: <db> + dbname: <db> host: <host> cp_min: 5 cp_max: 10 diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index a1ca5fa98c..a673975e04 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -1447,7 +1447,7 @@ database: args: user: synapse_user password: secretpassword - database: synapse + dbname: synapse host: localhost port: 5432 cp_min: 5 @@ -1526,7 +1526,7 @@ databases: args: user: synapse_user password: secretpassword - database: synapse_main + dbname: synapse_main host: localhost port: 5432 cp_min: 5 @@ -1539,7 +1539,7 @@ databases: args: user: synapse_user password: secretpassword - database: synapse_state + dbname: synapse_state host: localhost port: 5432 cp_min: 5 diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 8364ac8bc9..bb86419028 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -348,8 +348,7 @@ class Porter: backward_chunk = 0 already_ported = 0 else: - forward_chunk = row["forward_rowid"] - backward_chunk = row["backward_rowid"] + forward_chunk, backward_chunk = row if total_to_port is None: already_ported, total_to_port = await self._get_total_count_to_port( diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index c2109036ec..1027fbfd28 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import random -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union from synapse.api.errors import ( AuthError, @@ -23,6 +23,7 @@ from synapse.api.errors import ( StoreError, SynapseError, ) +from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util.caches.descriptors import cached from synapse.util.stringutils import parse_and_validate_mxc_uri @@ -306,7 +307,9 @@ class ProfileHandler: server_name = host if self._is_mine_server_name(server_name): - media_info = await self.store.get_local_media(media_id) + media_info: Optional[ + Union[LocalMedia, RemoteMedia] + ] = await self.store.get_local_media(media_id) else: media_info = await self.store.get_cached_remote_media(server_name, media_id) @@ -322,12 +325,12 @@ class ProfileHandler: if self.max_avatar_size: # Ensure avatar does not exceed max allowed avatar size - if media_info["media_length"] > self.max_avatar_size: + if media_info.media_length > self.max_avatar_size: logger.warning( "Forbidding avatar change to %s: %d bytes is above the allowed size " "limit", mxc, - media_info["media_length"], + media_info.media_length, ) return False @@ -335,12 +338,12 @@ class ProfileHandler: # Ensure the avatar's file type is allowed if ( self.allowed_avatar_mimetypes - and media_info["media_type"] not in self.allowed_avatar_mimetypes + and media_info.media_type not in self.allowed_avatar_mimetypes ): logger.warning( "Forbidding avatar change to %s: mimetype %s not allowed", mxc, - media_info["media_type"], + media_info.media_type, ) return False diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 6d680b0795..afd8138caf 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -269,7 +269,7 @@ class RoomCreationHandler: self, requester: Requester, old_room_id: str, - old_room: Dict[str, Any], + old_room: Tuple[bool, str, bool], new_room_id: str, new_version: RoomVersion, tombstone_event: EventBase, @@ -279,7 +279,7 @@ class RoomCreationHandler: Args: requester: the user requesting the upgrade old_room_id: the id of the room to be replaced - old_room: a dict containing room information for the room to be replaced, + old_room: a tuple containing room information for the room to be replaced, as returned by `RoomWorkerStore.get_room`. new_room_id: the id of the replacement room new_version: the version to upgrade the room to @@ -299,7 +299,7 @@ class RoomCreationHandler: await self.store.store_room( room_id=new_room_id, room_creator_user_id=user_id, - is_public=old_room["is_public"], + is_public=old_room[0], room_version=new_version, ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 918eb203e2..eddc2af9ba 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1260,7 +1260,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Add new room to the room directory if the old room was there # Remove old room from the room directory old_room = await self.store.get_room(old_room_id) - if old_room is not None and old_room["is_public"]: + # If the old room exists and is public. + if old_room is not None and old_room[0]: await self.store.set_room_is_public(old_room_id, False) await self.store.set_room_is_public(room_id, True) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 62f2454f5d..389dc5298a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -806,7 +806,7 @@ class SsoHandler: media_id = profile["avatar_url"].split("/")[-1] if self._is_mine_server_name(server_name): media = await self._media_repo.store.get_local_media(media_id) - if media is not None and upload_name == media["upload_name"]: + if media is not None and upload_name == media.upload_name: logger.info("skipping saving the user avatar") return True diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 72b0f1c5de..1957426c6a 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -19,6 +19,7 @@ import shutil from io import BytesIO from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple +import attr from matrix_common.types.mxc_uri import MXCUri import twisted.internet.error @@ -50,6 +51,7 @@ from synapse.media.storage_provider import StorageProviderWrapper from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.url_previewer import UrlPreviewer from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.databases.main.media_repository import RemoteMedia from synapse.types import UserID from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination @@ -245,18 +247,18 @@ class MediaRepository: Resolves once a response has successfully been written to request """ media_info = await self.store.get_local_media(media_id) - if not media_info or media_info["quarantined_by"]: + if not media_info or media_info.quarantined_by: respond_404(request) return self.mark_recently_accessed(None, media_id) - media_type = media_info["media_type"] + media_type = media_info.media_type if not media_type: media_type = "application/octet-stream" - media_length = media_info["media_length"] - upload_name = name if name else media_info["upload_name"] - url_cache = media_info["url_cache"] + media_length = media_info.media_length + upload_name = name if name else media_info.upload_name + url_cache = media_info.url_cache file_info = FileInfo(None, media_id, url_cache=bool(url_cache)) @@ -310,16 +312,20 @@ class MediaRepository: # We deliberately stream the file outside the lock if responder: - media_type = media_info["media_type"] - media_length = media_info["media_length"] - upload_name = name if name else media_info["upload_name"] + upload_name = name if name else media_info.upload_name await respond_with_responder( - request, responder, media_type, media_length, upload_name + request, + responder, + media_info.media_type, + media_info.media_length, + upload_name, ) else: respond_404(request) - async def get_remote_media_info(self, server_name: str, media_id: str) -> dict: + async def get_remote_media_info( + self, server_name: str, media_id: str + ) -> RemoteMedia: """Gets the media info associated with the remote file, downloading if necessary. @@ -353,7 +359,7 @@ class MediaRepository: async def _get_remote_media_impl( self, server_name: str, media_id: str - ) -> Tuple[Optional[Responder], dict]: + ) -> Tuple[Optional[Responder], RemoteMedia]: """Looks for media in local cache, if not there then attempt to download from remote server. @@ -373,15 +379,17 @@ class MediaRepository: # If we have an entry in the DB, try and look for it if media_info: - file_id = media_info["filesystem_id"] + file_id = media_info.filesystem_id file_info = FileInfo(server_name, file_id) - if media_info["quarantined_by"]: + if media_info.quarantined_by: logger.info("Media is quarantined") raise NotFoundError() - if not media_info["media_type"]: - media_info["media_type"] = "application/octet-stream" + if not media_info.media_type: + media_info = attr.evolve( + media_info, media_type="application/octet-stream" + ) responder = await self.media_storage.fetch_media(file_info) if responder: @@ -403,9 +411,9 @@ class MediaRepository: if not media_info: raise e - file_id = media_info["filesystem_id"] - if not media_info["media_type"]: - media_info["media_type"] = "application/octet-stream" + file_id = media_info.filesystem_id + if not media_info.media_type: + media_info = attr.evolve(media_info, media_type="application/octet-stream") file_info = FileInfo(server_name, file_id) # We generate thumbnails even if another process downloaded the media @@ -415,7 +423,7 @@ class MediaRepository: # otherwise they'll request thumbnails and get a 404 if they're not # ready yet. await self._generate_thumbnails( - server_name, media_id, file_id, media_info["media_type"] + server_name, media_id, file_id, media_info.media_type ) responder = await self.media_storage.fetch_media(file_info) @@ -425,7 +433,7 @@ class MediaRepository: self, server_name: str, media_id: str, - ) -> dict: + ) -> RemoteMedia: """Attempt to download the remote file from the given server name, using the given file_id as the local id. @@ -518,7 +526,7 @@ class MediaRepository: origin=server_name, media_id=media_id, media_type=media_type, - time_now_ms=self.clock.time_msec(), + time_now_ms=time_now_ms, upload_name=upload_name, media_length=length, filesystem_id=file_id, @@ -526,15 +534,17 @@ class MediaRepository: logger.info("Stored remote media in file %r", fname) - media_info = { - "media_type": media_type, - "media_length": length, - "upload_name": upload_name, - "created_ts": time_now_ms, - "filesystem_id": file_id, - } - - return media_info + return RemoteMedia( + media_origin=server_name, + media_id=media_id, + media_type=media_type, + media_length=length, + upload_name=upload_name, + created_ts=time_now_ms, + filesystem_id=file_id, + last_access_ts=time_now_ms, + quarantined_by=None, + ) def _get_thumbnail_requirements( self, media_type: str diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py index 9b5a3dd5f4..44aac21de6 100644 --- a/synapse/media/url_previewer.py +++ b/synapse/media/url_previewer.py @@ -240,15 +240,14 @@ class UrlPreviewer: cache_result = await self.store.get_url_cache(url, ts) if ( cache_result - and cache_result["expires_ts"] > ts - and cache_result["response_code"] / 100 == 2 + and cache_result.expires_ts > ts + and cache_result.response_code // 100 == 2 ): # It may be stored as text in the database, not as bytes (such as # PostgreSQL). If so, encode it back before handing it on. - og = cache_result["og"] - if isinstance(og, str): - og = og.encode("utf8") - return og + if isinstance(cache_result.og, str): + return cache_result.og.encode("utf8") + return cache_result.og # If this URL can be accessed via an allowed oEmbed, use that instead. url_to_download = url diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 755c59274c..812144a128 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -1860,7 +1860,8 @@ class PublicRoomListManager: if not room: return False - return room.get("is_public", False) + # The first item is whether the room is public. + return room[0] async def add_room_to_public_room_list(self, room_id: str) -> None: """Publishes a room to the public room list. diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 23a034522c..7e40bea8aa 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -413,8 +413,8 @@ class RoomMembersRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - ret = await self.store.get_room(room_id) - if not ret: + room = await self.store.get_room(room_id) + if not room: raise NotFoundError("Room not found") members = await self.store.get_users_in_room(room_id) @@ -442,8 +442,8 @@ class RoomStateRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - ret = await self.store.get_room(room_id) - if not ret: + room = await self.store.get_room(room_id) + if not room: raise NotFoundError("Room not found") event_ids = await self._storage_controllers.state.get_current_state_ids(room_id) diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py index 82944ca711..3534c3c259 100644 --- a/synapse/rest/client/directory.py +++ b/synapse/rest/client/directory.py @@ -147,7 +147,7 @@ class ClientDirectoryListServer(RestServlet): if room is None: raise NotFoundError("Unknown room") - return 200, {"visibility": "public" if room["is_public"] else "private"} + return 200, {"visibility": "public" if room[0] else "private"} class PutBody(RequestBodyModel): visibility: Literal["public", "private"] = "public" diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py index 85b6bdbe72..efda8b4ab4 100644 --- a/synapse/rest/media/thumbnail_resource.py +++ b/synapse/rest/media/thumbnail_resource.py @@ -119,7 +119,7 @@ class ThumbnailResource(RestServlet): if not media_info: respond_404(request) return - if media_info["quarantined_by"]: + if media_info.quarantined_by: logger.info("Media is quarantined") respond_404(request) return @@ -134,7 +134,7 @@ class ThumbnailResource(RestServlet): thumbnail_infos, media_id, media_id, - url_cache=bool(media_info["url_cache"]), + url_cache=bool(media_info.url_cache), server_name=None, ) @@ -152,7 +152,7 @@ class ThumbnailResource(RestServlet): if not media_info: respond_404(request) return - if media_info["quarantined_by"]: + if media_info.quarantined_by: logger.info("Media is quarantined") respond_404(request) return @@ -168,7 +168,7 @@ class ThumbnailResource(RestServlet): file_info = FileInfo( server_name=None, file_id=media_id, - url_cache=media_info["url_cache"], + url_cache=bool(media_info.url_cache), thumbnail=info, ) @@ -188,7 +188,7 @@ class ThumbnailResource(RestServlet): desired_height, desired_method, desired_type, - url_cache=bool(media_info["url_cache"]), + url_cache=bool(media_info.url_cache), ) if file_path: @@ -213,7 +213,7 @@ class ThumbnailResource(RestServlet): server_name, media_id ) - file_id = media_info["filesystem_id"] + file_id = media_info.filesystem_id for info in thumbnail_infos: t_w = info.width == desired_width @@ -224,7 +224,7 @@ class ThumbnailResource(RestServlet): if t_w and t_h and t_method and t_type: file_info = FileInfo( server_name=server_name, - file_id=media_info["filesystem_id"], + file_id=file_id, thumbnail=info, ) @@ -280,7 +280,7 @@ class ThumbnailResource(RestServlet): m_type, thumbnail_infos, media_id, - media_info["filesystem_id"], + media_info.filesystem_id, url_cache=False, server_name=server_name, ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 05775425b7..92b3c77b6d 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1660,7 +1660,7 @@ class DatabasePool: retcols: Collection[str], allow_none: Literal[False] = False, desc: str = "simple_select_one", - ) -> Dict[str, Any]: + ) -> Tuple[Any, ...]: ... @overload @@ -1671,7 +1671,7 @@ class DatabasePool: retcols: Collection[str], allow_none: Literal[True] = True, desc: str = "simple_select_one", - ) -> Optional[Dict[str, Any]]: + ) -> Optional[Tuple[Any, ...]]: ... async def simple_select_one( @@ -1681,7 +1681,7 @@ class DatabasePool: retcols: Collection[str], allow_none: bool = False, desc: str = "simple_select_one", - ) -> Optional[Dict[str, Any]]: + ) -> Optional[Tuple[Any, ...]]: """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. @@ -2190,7 +2190,7 @@ class DatabasePool: keyvalues: Dict[str, Any], retcols: Collection[str], allow_none: bool = False, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[Tuple[Any, ...]]: select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table) if keyvalues: @@ -2208,7 +2208,7 @@ class DatabasePool: if txn.rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) - return dict(zip(retcols, row)) + return row async def simple_delete_one( self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one" diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index d7482a1f4e..07f9b65af3 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -747,8 +747,16 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) # Invalidate the cache for any ignored users which were added or removed. - for ignored_user_id in previously_ignored_users ^ currently_ignored_users: - self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) + self._invalidate_cache_and_stream_bulk( + txn, + self.ignored_by, + [ + (ignored_user_id,) + for ignored_user_id in ( + previously_ignored_users ^ currently_ignored_users + ) + ], + ) self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) async def remove_account_data_for_user( @@ -824,10 +832,14 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) # Invalidate the cache for ignored users which were removed. - for ignored_user_id in previously_ignored_users: - self._invalidate_cache_and_stream( - txn, self.ignored_by, (ignored_user_id,) - ) + self._invalidate_cache_and_stream_bulk( + txn, + self.ignored_by, + [ + (ignored_user_id,) + for ignored_user_id in previously_ignored_users + ], + ) # Invalidate for this user the cache tracking ignored users. self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 4d0470ffd9..d7232f566b 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -483,6 +483,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore): txn.call_after(cache_func.invalidate, keys) self._send_invalidation_to_replication(txn, cache_func.__name__, keys) + def _invalidate_cache_and_stream_bulk( + self, + txn: LoggingTransaction, + cache_func: CachedFunction, + key_tuples: Collection[Tuple[Any, ...]], + ) -> None: + """A bulk version of _invalidate_cache_and_stream. + + Locally invalidate every key-tuple in `key_tuples`, then emit invalidations + for each key-tuple over replication. + + This implementation is more efficient than a loop which repeatedly calls the + non-bulk version. + """ + if not key_tuples: + return + + for keys in key_tuples: + txn.call_after(cache_func.invalidate, keys) + + self._send_invalidation_to_replication_bulk( + txn, cache_func.__name__, key_tuples + ) + def _invalidate_all_cache_and_stream( self, txn: LoggingTransaction, cache_func: CachedFunction ) -> None: @@ -564,10 +588,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if isinstance(self.database_engine, PostgresEngine): assert self._cache_id_gen is not None - # get_next() returns a context manager which is designed to wrap - # the transaction. However, we want to only get an ID when we want - # to use it, here, so we need to call __enter__ manually, and have - # __exit__ called after the transaction finishes. stream_id = self._cache_id_gen.get_next_txn(txn) txn.call_after(self.hs.get_notifier().on_new_replication_data) @@ -586,6 +606,53 @@ class CacheInvalidationWorkerStore(SQLBaseStore): }, ) + def _send_invalidation_to_replication_bulk( + self, + txn: LoggingTransaction, + cache_name: str, + key_tuples: Collection[Tuple[Any, ...]], + ) -> None: + """Announce the invalidation of multiple (but not all) cache entries. + + This is more efficient than repeated calls to the non-bulk version. It should + NOT be used to invalidating the entire cache: use + `_send_invalidation_to_replication` with keys=None. + + Note that this does *not* invalidate the cache locally. + + Args: + txn + cache_name + key_tuples: Key-tuples to invalidate. Assumed to be non-empty. + """ + if isinstance(self.database_engine, PostgresEngine): + assert self._cache_id_gen is not None + + stream_ids = self._cache_id_gen.get_next_mult_txn(txn, len(key_tuples)) + ts = self._clock.time_msec() + txn.call_after(self.hs.get_notifier().on_new_replication_data) + self.db_pool.simple_insert_many_txn( + txn, + table="cache_invalidation_stream_by_instance", + keys=( + "stream_id", + "instance_name", + "cache_func", + "keys", + "invalidation_ts", + ), + values=[ + # We convert key_tuples to a list here because psycopg2 serialises + # lists as pq arrrays, but serialises tuples as "composite types". + # (We need an array because the `keys` column has type `[]text`.) + # See: + # https://www.psycopg.org/docs/usage.html#adapt-list + # https://www.psycopg.org/docs/usage.html#adapt-tuple + (stream_id, self._instance_name, cache_name, list(key_tuple), ts) + for stream_id, key_tuple in zip(stream_ids, key_tuples) + ], + ) + def get_cache_stream_token_for_writer(self, instance_name: str) -> int: if self._cache_id_gen: return self._cache_id_gen.get_current_token_for_writer(instance_name) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 04d12a876c..775abbac79 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -255,33 +255,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): A dict containing the device information, or `None` if the device does not exist. """ - return await self.db_pool.simple_select_one( - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, - retcols=("user_id", "device_id", "display_name"), - desc="get_device", - allow_none=True, - ) - - async def get_device_opt( - self, user_id: str, device_id: str - ) -> Optional[Dict[str, Any]]: - """Retrieve a device. Only returns devices that are not marked as - hidden. - - Args: - user_id: The ID of the user which owns the device - device_id: The ID of the device to retrieve - Returns: - A dict containing the device information, or None if the device does not exist. - """ - return await self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), desc="get_device", allow_none=True, ) + if row is None: + return None + return {"user_id": row[0], "device_id": row[1], "display_name": row[2]} async def get_devices_by_user( self, user_id: str @@ -1221,9 +1204,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): retcols=["device_id", "device_data"], allow_none=True, ) - return ( - (row["device_id"], json_decoder.decode(row["device_data"])) if row else None - ) + return (row[0], json_decoder.decode(row[1])) if row else None def _store_dehydrated_device_txn( self, @@ -2326,13 +2307,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): `FALSE` have not been converted. """ - row = await self.db_pool.simple_select_one( - table="device_lists_changes_converted_stream_position", - keyvalues={}, - retcols=["stream_id", "room_id"], - desc="get_device_change_last_converted_pos", + return cast( + Tuple[int, str], + await self.db_pool.simple_select_one( + table="device_lists_changes_converted_stream_position", + keyvalues={}, + retcols=["stream_id", "room_id"], + desc="get_device_change_last_converted_pos", + ), ) - return row["stream_id"], row["room_id"] async def set_device_change_last_converted_pos( self, diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index ad904a26a6..fae23c3407 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -506,19 +506,26 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): # it isn't there. raise StoreError(404, "No backup with that version exists") - result = self.db_pool.simple_select_one_txn( - txn, - table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, - retcols=("version", "algorithm", "auth_data", "etag"), - allow_none=False, + row = cast( + Tuple[int, str, str, Optional[int]], + self.db_pool.simple_select_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={ + "user_id": user_id, + "version": this_version, + "deleted": 0, + }, + retcols=("version", "algorithm", "auth_data", "etag"), + allow_none=False, + ), ) - assert result is not None # see comment on `simple_select_one_txn` - result["auth_data"] = db_to_json(result["auth_data"]) - result["version"] = str(result["version"]) - if result["etag"] is None: - result["etag"] = 0 - return result + return { + "auth_data": db_to_json(row[2]), + "version": str(row[0]), + "algorithm": row[1], + "etag": 0 if row[3] is None else row[3], + } return await self.db_pool.runInteraction( "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 4f96ac25c7..8cb61eaee3 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1237,13 +1237,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker for user_id, device_id, algorithm, key_id, key_json in claimed_keys: device_results = results.setdefault(user_id, {}).setdefault(device_id, {}) device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json) - - if (user_id, device_id) in seen_user_device: - continue seen_user_device.add((user_id, device_id)) - self._invalidate_cache_and_stream( - txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) - ) + + self._invalidate_cache_and_stream_bulk( + txn, self.get_e2e_unused_fallback_key_types, seen_user_device + ) return results @@ -1268,9 +1266,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker if row is None: continue - key_id = row["key_id"] - key_json = row["key_json"] - used = row["used"] + key_id, key_json, used = row # Mark fallback key as used if not already. if not used and mark_as_used: @@ -1376,14 +1372,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list) ) - seen_user_device: Set[Tuple[str, str]] = set() - for user_id, device_id, _, _, _ in otk_rows: - if (user_id, device_id) in seen_user_device: - continue - seen_user_device.add((user_id, device_id)) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) + seen_user_device = { + (user_id, device_id) for user_id, device_id, _, _, _ in otk_rows + } + self._invalidate_cache_and_stream_bulk( + txn, + self.count_e2e_one_time_keys, + seen_user_device, + ) return otk_rows diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index b8bbd1eccd..98556a0523 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -193,7 +193,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Check if we have indexed the room so we can use the chain cover # algorithm. room = await self.get_room(room_id) # type: ignore[attr-defined] - if room["has_auth_chain_index"]: + # If the room has an auth chain index. + if room[1]: try: return await self.db_pool.runInteraction( "get_auth_chain_ids_chains", @@ -410,7 +411,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Check if we have indexed the room so we can use the chain cover # algorithm. room = await self.get_room(room_id) # type: ignore[attr-defined] - if room["has_auth_chain_index"]: + # If the room has an auth chain index. + if room[1]: try: return await self.db_pool.runInteraction( "get_auth_chain_difference_chains", @@ -1436,24 +1438,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) if event_lookup_result is not None: + event_type, depth, stream_ordering = event_lookup_result logger.debug( "_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s", room_id, seed_event_id, - event_lookup_result["depth"], - event_lookup_result["stream_ordering"], - event_lookup_result["type"], + depth, + stream_ordering, + event_type, ) - if event_lookup_result["depth"]: - queue.put( - ( - -event_lookup_result["depth"], - -event_lookup_result["stream_ordering"], - seed_event_id, - event_lookup_result["type"], - ) - ) + if depth: + queue.put((-depth, -stream_ordering, seed_event_id, event_type)) while not queue.empty() and len(event_id_results) < limit: try: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 7c34bde3e5..5207cc0f4e 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1934,8 +1934,7 @@ class PersistEventsStore: if row is None: return - redacted_relates_to = row["relates_to_id"] - rel_type = row["relation_type"] + redacted_relates_to, rel_type = row self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 0061805150..9c46c5d7bd 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -1222,14 +1222,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) # Iterate the parent IDs and invalidate caches. - for parent_id in {r[1] for r in relations_to_insert}: - cache_tuple = (parent_id,) - self._invalidate_cache_and_stream( # type: ignore[attr-defined] - txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined] - ) - self._invalidate_cache_and_stream( # type: ignore[attr-defined] - txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined] - ) + cache_tuples = {(r[1],) for r in relations_to_insert} + self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined] + txn, self.get_relations_for_event, cache_tuples # type: ignore[attr-defined] + ) + self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined] + txn, self.get_thread_summary, cache_tuples # type: ignore[attr-defined] + ) if results: latest_event_id = results[-1][0] diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 5bf864c1fb..4e63a16fa2 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1998,7 +1998,7 @@ class EventsWorkerStore(SQLBaseStore): if not res: raise SynapseError(404, "Could not find event %s" % (event_id,)) - return int(res["topological_ordering"]), int(res["stream_ordering"]) + return int(res[0]), int(res[1]) async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]: """Retrieve the entry with the lowest expiry timestamp in the event_expiry diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index ce88772f9e..c700872fdc 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -107,13 +107,16 @@ class KeyStore(CacheInvalidationWorkerStore): # invalidate takes a tuple corresponding to the params of # _get_server_keys_json. _get_server_keys_json only takes one # param, which is itself the 2-tuple (server_name, key_id). - for key_id in verify_keys: - self._invalidate_cache_and_stream( - txn, self._get_server_keys_json, ((server_name, key_id),) - ) - self._invalidate_cache_and_stream( - txn, self.get_server_key_json_for_remote, (server_name, key_id) - ) + self._invalidate_cache_and_stream_bulk( + txn, + self._get_server_keys_json, + [((server_name, key_id),) for key_id in verify_keys], + ) + self._invalidate_cache_and_stream_bulk( + txn, + self.get_server_key_json_for_remote, + [(server_name, key_id) for key_id in verify_keys], + ) await self.db_pool.runInteraction( "store_server_keys_response", store_server_keys_response_txn diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index c8d7c9fd32..3f80a64dc5 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -15,9 +15,7 @@ from enum import Enum from typing import ( TYPE_CHECKING, - Any, Collection, - Dict, Iterable, List, Optional, @@ -54,11 +52,32 @@ class LocalMedia: media_length: int upload_name: str created_ts: int + url_cache: Optional[str] last_access_ts: int quarantined_by: Optional[str] safe_from_quarantine: bool +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RemoteMedia: + media_origin: str + media_id: str + media_type: str + media_length: int + upload_name: Optional[str] + filesystem_id: str + created_ts: int + last_access_ts: int + quarantined_by: Optional[str] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class UrlCache: + response_code: int + expires_ts: int + og: Union[str, bytes] + + class MediaSortOrder(Enum): """ Enum to define the sorting method used when returning media with @@ -165,13 +184,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): super().__init__(database, db_conn, hs) self.server_name: str = hs.hostname - async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: + async def get_local_media(self, media_id: str) -> Optional[LocalMedia]: """Get the metadata for a local piece of media Returns: None if the media_id doesn't exist. """ - return await self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -181,11 +200,25 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "created_ts", "quarantined_by", "url_cache", + "last_access_ts", "safe_from_quarantine", ), allow_none=True, desc="get_local_media", ) + if row is None: + return None + return LocalMedia( + media_id=media_id, + media_type=row[0], + media_length=row[1], + upload_name=row[2], + created_ts=row[3], + quarantined_by=row[4], + url_cache=row[5], + last_access_ts=row[6], + safe_from_quarantine=row[7], + ) async def get_local_media_by_user_paginate( self, @@ -236,6 +269,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): media_length, upload_name, created_ts, + url_cache, last_access_ts, quarantined_by, safe_from_quarantine @@ -257,9 +291,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): media_length=row[2], upload_name=row[3], created_ts=row[4], - last_access_ts=row[5], - quarantined_by=row[6], - safe_from_quarantine=bool(row[7]), + url_cache=row[5], + last_access_ts=row[6], + quarantined_by=row[7], + safe_from_quarantine=bool(row[8]), ) for row in txn ] @@ -390,51 +425,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="mark_local_media_as_safe", ) - async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]: + async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]: """Get the media_id and ts for a cached URL as of the given timestamp Returns: None if the URL isn't cached. """ - def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: + def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]: # get the most recently cached result (relative to the given ts) - sql = ( - "SELECT response_code, etag, expires_ts, og, media_id, download_ts" - " FROM local_media_repository_url_cache" - " WHERE url = ? AND download_ts <= ?" - " ORDER BY download_ts DESC LIMIT 1" - ) + sql = """ + SELECT response_code, expires_ts, og + FROM local_media_repository_url_cache + WHERE url = ? AND download_ts <= ? + ORDER BY download_ts DESC LIMIT 1 + """ txn.execute(sql, (url, ts)) row = txn.fetchone() if not row: # ...or if we've requested a timestamp older than the oldest # copy in the cache, return the oldest copy (if any) - sql = ( - "SELECT response_code, etag, expires_ts, og, media_id, download_ts" - " FROM local_media_repository_url_cache" - " WHERE url = ? AND download_ts > ?" - " ORDER BY download_ts ASC LIMIT 1" - ) + sql = """ + SELECT response_code, expires_ts, og + FROM local_media_repository_url_cache + WHERE url = ? AND download_ts > ? + ORDER BY download_ts ASC LIMIT 1 + """ txn.execute(sql, (url, ts)) row = txn.fetchone() if not row: return None - return dict( - zip( - ( - "response_code", - "etag", - "expires_ts", - "og", - "media_id", - "download_ts", - ), - row, - ) - ) + return UrlCache(response_code=row[0], expires_ts=row[1], og=row[2]) return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) @@ -444,7 +467,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): response_code: int, etag: Optional[str], expires_ts: int, - og: Optional[str], + og: str, media_id: str, download_ts: int, ) -> None: @@ -510,8 +533,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def get_cached_remote_media( self, origin: str, media_id: str - ) -> Optional[Dict[str, Any]]: - return await self.db_pool.simple_select_one( + ) -> Optional[RemoteMedia]: + row = await self.db_pool.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( @@ -520,11 +543,25 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "upload_name", "created_ts", "filesystem_id", + "last_access_ts", "quarantined_by", ), allow_none=True, desc="get_cached_remote_media", ) + if row is None: + return row + return RemoteMedia( + media_origin=origin, + media_id=media_id, + media_type=row[0], + media_length=row[1], + upload_name=row[2], + created_ts=row[3], + filesystem_id=row[4], + last_access_ts=row[5], + quarantined_by=row[6], + ) async def store_cached_remote_media( self, @@ -623,10 +660,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): t_width: int, t_height: int, t_type: str, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[ThumbnailInfo]: """Fetch the thumbnail info of given width, height and type.""" - return await self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( table="remote_media_cache_thumbnails", keyvalues={ "media_origin": origin, @@ -641,11 +678,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "thumbnail_method", "thumbnail_type", "thumbnail_length", - "filesystem_id", ), allow_none=True, desc="get_remote_media_thumbnail", ) + if row is None: + return None + return ThumbnailInfo( + width=row[0], height=row[1], method=row[2], type=row[3], length=row[4] + ) @trace async def store_remote_media_thumbnail( diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 3b444d2d07..0198bb09d2 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -363,10 +363,11 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) # for their user ID. value_values=[(presence_stream_id,) for _ in user_ids], ) - for user_id in user_ids: - self._invalidate_cache_and_stream( - txn, self._get_full_presence_stream_token_for_user, (user_id,) - ) + self._invalidate_cache_and_stream_bulk( + txn, + self._get_full_presence_stream_token_for_user, + [(user_id,) for user_id in user_ids], + ) return await self.db_pool.runInteraction( "add_users_to_send_full_presence_to", _add_users_to_send_full_presence_to diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 3ba9cc8853..7ed111f632 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import TYPE_CHECKING, Optional -from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -138,23 +137,18 @@ class ProfileWorkerStore(SQLBaseStore): return 50 async def get_profileinfo(self, user_id: UserID) -> ProfileInfo: - try: - profile = await self.db_pool.simple_select_one( - table="profiles", - keyvalues={"full_user_id": user_id.to_string()}, - retcols=("displayname", "avatar_url"), - desc="get_profileinfo", - ) - except StoreError as e: - if e.code == 404: - # no match - return ProfileInfo(None, None) - else: - raise - - return ProfileInfo( - avatar_url=profile["avatar_url"], display_name=profile["displayname"] + profile = await self.db_pool.simple_select_one( + table="profiles", + keyvalues={"full_user_id": user_id.to_string()}, + retcols=("displayname", "avatar_url"), + desc="get_profileinfo", + allow_none=True, ) + if profile is None: + # no match + return ProfileInfo(None, None) + + return ProfileInfo(avatar_url=profile[1], display_name=profile[0]) async def get_profile_displayname(self, user_id: UserID) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 1e11bf2706..c3b3e2baaf 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -295,19 +295,28 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # so make sure to keep this actually last. txn.execute("DROP TABLE events_to_purge") - for event_id, should_delete in event_rows: - self._invalidate_cache_and_stream( - txn, self._get_state_group_for_event, (event_id,) - ) + self._invalidate_cache_and_stream_bulk( + txn, + self._get_state_group_for_event, + [(event_id,) for event_id, _ in event_rows], + ) - # XXX: This is racy, since have_seen_events could be called between the - # transaction completing and the invalidation running. On the other hand, - # that's no different to calling `have_seen_events` just before the - # event is deleted from the database. + # XXX: This is racy, since have_seen_events could be called between the + # transaction completing and the invalidation running. On the other hand, + # that's no different to calling `have_seen_events` just before the + # event is deleted from the database. + self._invalidate_cache_and_stream_bulk( + txn, + self.have_seen_event, + [ + (room_id, event_id) + for event_id, should_delete in event_rows + if should_delete + ], + ) + + for event_id, should_delete in event_rows: if should_delete: - self._invalidate_cache_and_stream( - txn, self.have_seen_event, (room_id, event_id) - ) self.invalidate_get_event_cache_after_txn(txn, event_id) logger.info("[purge] done") diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 37135d431d..f72a23c584 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -468,8 +468,7 @@ class PushRuleStore(PushRulesWorkerStore): "before/after rule not found: %s" % (relative_to_rule,) ) - base_priority_class = res["priority_class"] - base_rule_priority = res["priority"] + base_priority_class, base_rule_priority = res if base_priority_class != priority_class: raise InconsistentRuleException( diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 56e8eb16a8..3484ce9ef9 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -701,8 +701,8 @@ class ReceiptsWorkerStore(SQLBaseStore): allow_none=True, ) - stream_ordering = int(res["stream_ordering"]) if res else None - rx_ts = res["received_ts"] if res else 0 + stream_ordering = int(res[0]) if res else None + rx_ts = res[1] if res else 0 # We don't want to clobber receipts for more recent events, so we # have to compare orderings of existing receipts diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 933d76e905..2c3f30e2eb 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): account timestamp as milliseconds since the epoch. None if the account has not been renewed using the current token yet. """ - ret_dict = await self.db_pool.simple_select_one( - table="account_validity", - keyvalues={"renewal_token": renewal_token}, - retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], - desc="get_user_from_renewal_token", - ) - - return ( - ret_dict["user_id"], - ret_dict["expiration_ts_ms"], - ret_dict["token_used_ts_ms"], + return cast( + Tuple[str, int, Optional[int]], + await self.db_pool.simple_select_one( + table="account_validity", + keyvalues={"renewal_token": renewal_token}, + retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], + desc="get_user_from_renewal_token", + ), ) async def get_renewal_token_for_user(self, user_id: str) -> str: @@ -564,16 +561,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): updatevalues={"shadow_banned": shadow_banned}, ) # In order for this to apply immediately, clear the cache for this user. - tokens = self.db_pool.simple_select_onecol_txn( + tokens = self.db_pool.simple_select_list_txn( txn, table="access_tokens", keyvalues={"user_id": user_id}, - retcol="token", + retcols=("token",), + ) + self._invalidate_cache_and_stream_bulk( + txn, self.get_user_by_access_token, tokens ) - for token in tokens: - self._invalidate_cache_and_stream( - txn, self.get_user_by_access_token, (token,) - ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn) @@ -989,16 +985,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Returns: user id, or None if no user id/threepid mapping exists """ - ret = self.db_pool.simple_select_one_txn( + return self.db_pool.simple_select_one_onecol_txn( txn, "user_threepids", {"medium": medium, "address": address}, - ["user_id"], + "user_id", True, ) - if ret: - return ret["user_id"] - return None async def user_add_threepid( self, @@ -1435,16 +1428,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): if res is None: return False + uses_allowed, pending, completed, expiry_time = res + # Check if the token has expired now = self._clock.time_msec() - if res["expiry_time"] and res["expiry_time"] < now: + if expiry_time and expiry_time < now: return False # Check if the token has been used up - if ( - res["uses_allowed"] - and res["pending"] + res["completed"] >= res["uses_allowed"] - ): + if uses_allowed and pending + completed >= uses_allowed: return False # Otherwise, the token is valid @@ -1490,8 +1482,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # Override type because the return type is only optional if # allow_none is True, and we don't want mypy throwing errors # about None not being indexable. - res = cast( - Dict[str, Any], + pending, completed = cast( + Tuple[int, int], self.db_pool.simple_select_one_txn( txn, "registration_tokens", @@ -1506,8 +1498,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "registration_tokens", keyvalues={"token": token}, updatevalues={ - "completed": res["completed"] + 1, - "pending": res["pending"] - 1, + "completed": completed + 1, + "pending": pending - 1, }, ) @@ -1585,13 +1577,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Returns: A dict, or None if token doesn't exist. """ - return await self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( "registration_tokens", keyvalues={"token": token}, retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"], allow_none=True, desc="get_one_registration_token", ) + if row is None: + return None + return { + "token": row[0], + "uses_allowed": row[1], + "pending": row[2], + "completed": row[3], + "expiry_time": row[4], + } async def generate_registration_token( self, length: int, chars: str @@ -1714,7 +1715,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return None # Get all info about the token so it can be sent in the response - return self.db_pool.simple_select_one_txn( + result = self.db_pool.simple_select_one_txn( txn, "registration_tokens", keyvalues={"token": token}, @@ -1728,6 +1729,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): allow_none=True, ) + if result is None: + return result + + return { + "token": result[0], + "uses_allowed": result[1], + "pending": result[2], + "completed": result[3], + "expiry_time": result[4], + } + return await self.db_pool.runInteraction( "update_registration_token", _update_registration_token_txn ) @@ -1939,11 +1951,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): keyvalues={"token": token}, updatevalues={"used_ts": ts}, ) - user_id = values["user_id"] - expiry_ts = values["expiry_ts"] - used_ts = values["used_ts"] - auth_provider_id = values["auth_provider_id"] - auth_provider_session_id = values["auth_provider_session_id"] + ( + user_id, + expiry_ts, + used_ts, + auth_provider_id, + auth_provider_session_id, + ) = values # Token was already used if used_ts is not None: @@ -2668,10 +2682,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): ) tokens_and_devices = [(r[0], r[1], r[2]) for r in txn] - for token, _, _ in tokens_and_devices: - self._invalidate_cache_and_stream( - txn, self.get_user_by_access_token, (token,) - ) + self._invalidate_cache_and_stream_bulk( + txn, + self.get_user_by_access_token, + [(token,) for token, _, _ in tokens_and_devices], + ) txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values) @@ -2756,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): # reason, the next check is on the client secret, which is NOT NULL, # so we don't have to worry about the client secret matching by # accident. - row = {"client_secret": None, "validated_at": None} + row = None, None else: raise ThreepidValidationError("Unknown session_id") - retrieved_client_secret = row["client_secret"] - validated_at = row["validated_at"] + retrieved_client_secret, validated_at = row row = self.db_pool.simple_select_one_txn( txn, @@ -2775,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): raise ThreepidValidationError( "Validation token not found or has expired" ) - expires = row["expires"] - next_link = row["next_link"] + expires, next_link = row if retrieved_client_secret != client_secret: raise ThreepidValidationError( diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index afb880532e..ef26d5d9d3 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -213,21 +213,31 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") - async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]: + async def get_room(self, room_id: str) -> Optional[Tuple[bool, bool]]: """Retrieve a room. Args: room_id: The ID of the room to retrieve. Returns: - A dict containing the room information, or None if the room is unknown. + A tuple containing the room information: + * True if the room is public + * True if the room has an auth chain index + + or None if the room is unknown. """ - return await self.db_pool.simple_select_one( - table="rooms", - keyvalues={"room_id": room_id}, - retcols=("room_id", "is_public", "creator", "has_auth_chain_index"), - desc="get_room", - allow_none=True, + row = cast( + Optional[Tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]], + await self.db_pool.simple_select_one( + table="rooms", + keyvalues={"room_id": room_id}, + retcols=("is_public", "has_auth_chain_index"), + desc="get_room", + allow_none=True, + ), ) + if row is None: + return row + return bool(row[0]), bool(row[1]) async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]: """Retrieve room with statistics. @@ -794,10 +804,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): ) if row: - return RatelimitOverride( - messages_per_second=row["messages_per_second"], - burst_count=row["burst_count"], - ) + return RatelimitOverride(messages_per_second=row[0], burst_count=row[1]) else: return None @@ -1371,13 +1378,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): join. """ - result = await self.db_pool.simple_select_one( - table="partial_state_rooms", - keyvalues={"room_id": room_id}, - retcols=("join_event_id", "device_lists_stream_id"), - desc="get_join_event_id_for_partial_state", + return cast( + Tuple[str, int], + await self.db_pool.simple_select_one( + table="partial_state_rooms", + keyvalues={"room_id": room_id}, + retcols=("join_event_id", "device_lists_stream_id"), + desc="get_join_event_id_for_partial_state", + ), ) - return result["join_event_id"], result["device_lists_stream_id"] def get_un_partial_stated_rooms_token(self, instance_name: str) -> int: return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 1ed7f2d0ef..60d4a9ef30 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -559,17 +559,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): "non-local user %s" % (user_id,), ) - results_dict = await self.db_pool.simple_select_one( - "local_current_membership", - {"room_id": room_id, "user_id": user_id}, - ("membership", "event_id"), - allow_none=True, - desc="get_local_current_membership_for_user_in_room", + results = cast( + Optional[Tuple[str, str]], + await self.db_pool.simple_select_one( + "local_current_membership", + {"room_id": room_id, "user_id": user_id}, + ("membership", "event_id"), + allow_none=True, + desc="get_local_current_membership_for_user_in_room", + ), ) - if not results_dict: + if not results: return None, None - return results_dict.get("membership"), results_dict.get("event_id") + return results @cached(max_entries=500000, iterable=True) async def get_rooms_for_user_with_stream_ordering( diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 2225f8272d..563c275a2c 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1014,9 +1014,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): desc="get_position_for_event", ) - return PersistedEventPosition( - row["instance_name"] or "master", row["stream_ordering"] - ) + return PersistedEventPosition(row[1] or "master", row[0]) async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken: """The stream token for an event @@ -1033,9 +1031,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", ) - return RoomStreamToken( - topological=row["topological_ordering"], stream=row["stream_ordering"] - ) + return RoomStreamToken(topological=row[1], stream=row[0]) async def get_current_topological_token(self, room_id: str, stream_key: int) -> int: """Gets the topological token in a room after or at the given stream @@ -1180,26 +1176,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - results = self.db_pool.simple_select_one_txn( - txn, - "events", - keyvalues={"event_id": event_id, "room_id": room_id}, - retcols=["stream_ordering", "topological_ordering"], + stream_ordering, topological_ordering = cast( + Tuple[int, int], + self.db_pool.simple_select_one_txn( + txn, + "events", + keyvalues={"event_id": event_id, "room_id": room_id}, + retcols=["stream_ordering", "topological_ordering"], + ), ) - # This cannot happen as `allow_none=False`. - assert results is not None - # Paginating backwards includes the event at the token, but paginating # forward doesn't. before_token = RoomStreamToken( - topological=results["topological_ordering"] - 1, - stream=results["stream_ordering"], + topological=topological_ordering - 1, stream=stream_ordering ) after_token = RoomStreamToken( - topological=results["topological_ordering"], - stream=results["stream_ordering"], + topological=topological_ordering, stream=stream_ordering ) rows, start_token = self._paginate_room_events_txn( diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py index 5555b53575..64543b4d61 100644 --- a/synapse/storage/databases/main/task_scheduler.py +++ b/synapse/storage/databases/main/task_scheduler.py @@ -183,39 +183,27 @@ class TaskSchedulerWorkerStore(SQLBaseStore): Returns: the task if available, `None` otherwise """ - row = await self.db_pool.simple_select_one( - table="scheduled_tasks", - keyvalues={"id": id}, - retcols=( - "id", - "action", - "status", - "timestamp", - "resource_id", - "params", - "result", - "error", + row = cast( + Optional[ScheduledTaskRow], + await self.db_pool.simple_select_one( + table="scheduled_tasks", + keyvalues={"id": id}, + retcols=( + "id", + "action", + "status", + "timestamp", + "resource_id", + "params", + "result", + "error", + ), + allow_none=True, + desc="get_scheduled_task", ), - allow_none=True, - desc="get_scheduled_task", ) - return ( - TaskSchedulerWorkerStore._convert_row_to_task( - ( - row["id"], - row["action"], - row["status"], - row["timestamp"], - row["resource_id"], - row["params"], - row["result"], - row["error"], - ) - ) - if row - else None - ) + return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None async def delete_scheduled_task(self, id: str) -> None: """Delete a specific task from its id. diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index fecddb4144..2d341affaa 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -118,19 +118,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): txn, table="received_transactions", keyvalues={"transaction_id": transaction_id, "origin": origin}, - retcols=( - "transaction_id", - "origin", - "ts", - "response_code", - "response_json", - "has_been_referenced", - ), + retcols=("response_code", "response_json"), allow_none=True, ) - if result and result["response_code"]: - return result["response_code"], db_to_json(result["response_json"]) + # If the result exists and the response code is non-0. + if result and result[0]: + return result[0], db_to_json(result[1]) else: return None @@ -200,8 +194,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): # check we have a row and retry_last_ts is not null or zero # (retry_last_ts can't be negative) - if result and result["retry_last_ts"]: - return DestinationRetryTimings(**result) + if result and result[1]: + return DestinationRetryTimings( + failure_ts=result[0], retry_last_ts=result[1], retry_interval=result[2] + ) else: return None diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 8ab7c42c4a..5b164fed8e 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -122,9 +122,13 @@ class UIAuthWorkerStore(SQLBaseStore): desc="get_ui_auth_session", ) - result["clientdict"] = db_to_json(result["clientdict"]) - - return UIAuthSessionData(session_id, **result) + return UIAuthSessionData( + session_id, + clientdict=db_to_json(result[0]), + uri=result[1], + method=result[2], + description=result[3], + ) async def mark_ui_auth_stage_complete( self, @@ -231,18 +235,15 @@ class UIAuthWorkerStore(SQLBaseStore): self, txn: LoggingTransaction, session_id: str, key: str, value: Any ) -> None: # Get the current value. - result = cast( - Dict[str, Any], - self.db_pool.simple_select_one_txn( - txn, - table="ui_auth_sessions", - keyvalues={"session_id": session_id}, - retcols=("serverdict",), - ), + result = self.db_pool.simple_select_one_onecol_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcol="serverdict", ) # Update it and add it back to the database. - serverdict = db_to_json(result["serverdict"]) + serverdict = db_to_json(result) serverdict[key] = value self.db_pool.simple_update_one_txn( @@ -265,14 +266,14 @@ class UIAuthWorkerStore(SQLBaseStore): Raises: StoreError if the session cannot be found. """ - result = await self.db_pool.simple_select_one( + result = await self.db_pool.simple_select_one_onecol( table="ui_auth_sessions", keyvalues={"session_id": session_id}, - retcols=("serverdict",), + retcol="serverdict", desc="get_ui_auth_session_data", ) - serverdict = db_to_json(result["serverdict"]) + serverdict = db_to_json(result) return serverdict.get(key, default) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index d4b86ed7a6..93ec06904a 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -20,7 +20,6 @@ from typing import ( Collection, Iterable, List, - Mapping, Optional, Sequence, Set, @@ -868,13 +867,25 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): "delete_all_from_user_dir", _delete_all_from_user_dir_txn ) - async def _get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]: - return await self.db_pool.simple_select_one( - table="user_directory", - keyvalues={"user_id": user_id}, - retcols=("display_name", "avatar_url"), - allow_none=True, - desc="get_user_in_directory", + async def _get_user_in_directory( + self, user_id: str + ) -> Optional[Tuple[Optional[str], Optional[str]]]: + """ + Fetch the user information in the user directory. + + Returns: + None if the user is unknown, otherwise a tuple of display name and + avatar URL (both of which may be None). + """ + return cast( + Optional[Tuple[Optional[str], Optional[str]]], + await self.db_pool.simple_select_one( + table="user_directory", + keyvalues={"user_id": user_id}, + retcols=("display_name", "avatar_url"), + allow_none=True, + desc="get_user_in_directory", + ), ) async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None: diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 52d708ad17..76a74b8b8f 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -650,8 +650,8 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): next_id = self._load_next_id_txn(txn) - txn.call_after(self._mark_id_as_finished, next_id) - txn.call_on_exception(self._mark_id_as_finished, next_id) + txn.call_after(self._mark_ids_as_finished, [next_id]) + txn.call_on_exception(self._mark_ids_as_finished, [next_id]) txn.call_after(self._notifier.notify_replication) # Update the `stream_positions` table with newly updated stream @@ -671,14 +671,50 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): return self._return_factor * next_id - def _mark_id_as_finished(self, next_id: int) -> None: - """The ID has finished being processed so we should advance the + def get_next_mult_txn(self, txn: LoggingTransaction, n: int) -> List[int]: + """ + Usage: + + stream_id = stream_id_gen.get_next_txn(txn) + # ... persist event ... + """ + + # If we have a list of instances that are allowed to write to this + # stream, make sure we're in it. + if self._writers and self._instance_name not in self._writers: + raise Exception("Tried to allocate stream ID on non-writer") + + next_ids = self._load_next_mult_id_txn(txn, n) + + txn.call_after(self._mark_ids_as_finished, next_ids) + txn.call_on_exception(self._mark_ids_as_finished, next_ids) + txn.call_after(self._notifier.notify_replication) + + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persisted row with the correct instance name. + if self._writers: + txn.call_after( + run_as_background_process, + "MultiWriterIdGenerator._update_table", + self._db.runInteraction, + "MultiWriterIdGenerator._update_table", + self._update_stream_positions_table_txn, + ) + + return [self._return_factor * next_id for next_id in next_ids] + + def _mark_ids_as_finished(self, next_ids: List[int]) -> None: + """These IDs have finished being processed so we should advance the current position if possible. """ with self._lock: - self._unfinished_ids.discard(next_id) - self._finished_ids.add(next_id) + self._unfinished_ids.difference_update(next_ids) + self._finished_ids.update(next_ids) new_cur: Optional[int] = None @@ -727,7 +763,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): curr, new_cur, self._max_position_of_local_instance ) - self._add_persisted_position(next_id) + # TODO Can we call this for just the last position or somehow batch + # _add_persisted_position. + for next_id in next_ids: + self._add_persisted_position(next_id) def get_current_token(self) -> int: return self.get_persisted_upto_position() @@ -933,8 +972,7 @@ class _MultiWriterCtxManager: exc: Optional[BaseException], tb: Optional[TracebackType], ) -> bool: - for i in self.stream_ids: - self.id_gen._mark_id_as_finished(i) + self.id_gen._mark_ids_as_finished(self.stream_ids) self.notifier.notify_replication() diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 76c56d5434..15e19b15fb 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -84,7 +84,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) - return self.get_success( + row = self.get_success( self.store.db_pool.simple_select_one( table + "_current", {id_col: stat_id}, @@ -93,6 +93,8 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) + return None if row is None else dict(zip(cols, row)) + def _perform_background_initial_update(self) -> None: # Do the initial population of the stats via the background update self._add_background_updates() diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index b5f15aa7d4..388447eea6 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -366,7 +366,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) profile = self.get_success(self.store._get_user_in_directory(regular_user_id)) assert profile is not None - self.assertTrue(profile["display_name"] == display_name) + self.assertTrue(profile[0] == display_name) def test_handle_local_profile_change_with_deactivated_user(self) -> None: # create user @@ -385,7 +385,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # profile is in directory profile = self.get_success(self.store._get_user_in_directory(r_user_id)) assert profile is not None - self.assertTrue(profile["display_name"] == display_name) + self.assertEqual(profile[0], display_name) # deactivate user self.get_success(self.store.set_user_deactivated_status(r_user_id, True)) diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index 15f5d644e4..a8e7a76b29 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -504,7 +504,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): origin, media_id = self.media_id.split("/") info = self.get_success(self.store.get_cached_remote_media(origin, media_id)) assert info is not None - file_id = info["filesystem_id"] + file_id = info.filesystem_id thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir( origin, file_id diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 278808abb5..dac79bd745 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -642,7 +642,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertFalse(media_info["quarantined_by"]) + self.assertFalse(media_info.quarantined_by) # quarantining channel = self.make_request( @@ -656,7 +656,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertTrue(media_info["quarantined_by"]) + self.assertTrue(media_info.quarantined_by) # remove from quarantine channel = self.make_request( @@ -670,7 +670,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertFalse(media_info["quarantined_by"]) + self.assertFalse(media_info.quarantined_by) def test_quarantine_protected_media(self) -> None: """ @@ -683,7 +683,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests): # verify protection media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertTrue(media_info["safe_from_quarantine"]) + self.assertTrue(media_info.safe_from_quarantine) # quarantining channel = self.make_request( @@ -698,7 +698,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests): # verify that is not in quarantine media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertFalse(media_info["quarantined_by"]) + self.assertFalse(media_info.quarantined_by) class ProtectMediaByIDTestCase(_AdminMediaTests): @@ -756,7 +756,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertFalse(media_info["safe_from_quarantine"]) + self.assertFalse(media_info.safe_from_quarantine) # protect channel = self.make_request( @@ -770,7 +770,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertTrue(media_info["safe_from_quarantine"]) + self.assertTrue(media_info.safe_from_quarantine) # unprotect channel = self.make_request( @@ -784,7 +784,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertFalse(media_info["safe_from_quarantine"]) + self.assertFalse(media_info.safe_from_quarantine) class PurgeMediaCacheTestCase(_AdminMediaTests): diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 37f37a09d8..42b065d883 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2706,7 +2706,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # is in user directory profile = self.get_success(self.store._get_user_in_directory(self.other_user)) assert profile is not None - self.assertTrue(profile["display_name"] == "User") + self.assertEqual(profile[0], "User") # Deactivate user channel = self.make_request( diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index cffbda9a7d..bd59bb50cf 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -139,12 +139,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # # Note that we don't have the UI Auth session ID, so just pull out the single # row. - ui_auth_data = self.get_success( - self.store.db_pool.simple_select_one( - "ui_auth_sessions", keyvalues={}, retcols=("clientdict",) + result = self.get_success( + self.store.db_pool.simple_select_one_onecol( + "ui_auth_sessions", keyvalues={}, retcol="clientdict" ) ) - client_dict = db_to_json(ui_auth_data["clientdict"]) + client_dict = db_to_json(result) self.assertNotIn("new_password", client_dict) @override_config({"rc_3pid_validation": {"burst_count": 3}}) diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index ba4e017a0e..b04094b7b3 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -270,15 +270,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertLessEqual(det_data.items(), channel.json_body.items()) # Check the `completed` counter has been incremented and pending is 0 - res = self.get_success( + pending, completed = self.get_success( store.db_pool.simple_select_one( "registration_tokens", keyvalues={"token": token}, retcols=["pending", "completed"], ) ) - self.assertEqual(res["completed"], 1) - self.assertEqual(res["pending"], 0) + self.assertEqual(completed, 1) + self.assertEqual(pending, 0) @override_config({"registration_requires_token": True}) def test_POST_registration_token_invalid(self) -> None: @@ -372,15 +372,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): params1["auth"]["type"] = LoginType.DUMMY self.make_request(b"POST", self.url, params1) # Check pending=0 and completed=1 - res = self.get_success( + pending, completed = self.get_success( store.db_pool.simple_select_one( "registration_tokens", keyvalues={"token": token}, retcols=["pending", "completed"], ) ) - self.assertEqual(res["pending"], 0) - self.assertEqual(res["completed"], 1) + self.assertEqual(pending, 0) + self.assertEqual(completed, 1) # Check auth still fails when using token with session2 channel = self.make_request(b"POST", self.url, params2) diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py index b59d9dfd4d..27a663a23b 100644 --- a/tests/rest/media/test_media_retention.py +++ b/tests/rest/media/test_media_retention.py @@ -267,23 +267,23 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None: """Given an MXC URI, assert whether it has been purged or not.""" if mxc_uri.server_name == self.hs.config.server.server_name: - found_media_dict = self.get_success( - self.store.get_local_media(mxc_uri.media_id) + found_media = bool( + self.get_success(self.store.get_local_media(mxc_uri.media_id)) ) else: - found_media_dict = self.get_success( - self.store.get_cached_remote_media( - mxc_uri.server_name, mxc_uri.media_id + found_media = bool( + self.get_success( + self.store.get_cached_remote_media( + mxc_uri.server_name, mxc_uri.media_id + ) ) ) if expect_purged: - self.assertIsNone( - found_media_dict, msg=f"{mxc_uri} unexpectedly not purged" - ) + self.assertFalse(found_media, msg=f"{mxc_uri} unexpectedly not purged") else: - self.assertIsNotNone( - found_media_dict, + self.assertTrue( + found_media, msg=f"{mxc_uri} unexpectedly purged", ) diff --git a/tests/storage/databases/main/test_cache.py b/tests/storage/databases/main/test_cache.py new file mode 100644 index 0000000000..3f71f5d102 --- /dev/null +++ b/tests/storage/databases/main/test_cache.py @@ -0,0 +1,117 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import Mock, call + +from synapse.storage.database import LoggingTransaction + +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.unittest import HomeserverTestCase + + +class CacheInvalidationTestCase(HomeserverTestCase): + def setUp(self) -> None: + super().setUp() + self.store = self.hs.get_datastores().main + + def test_bulk_invalidation(self) -> None: + master_invalidate = Mock() + + self.store._get_cached_user_device.invalidate = master_invalidate + + keys_to_invalidate = [ + ("a", "b"), + ("c", "d"), + ("e", "f"), + ("g", "h"), + ] + + def test_txn(txn: LoggingTransaction) -> None: + self.store._invalidate_cache_and_stream_bulk( + txn, + # This is an arbitrarily chosen cached store function. It was chosen + # because it takes more than one argument. We'll use this later to + # check that the invalidation was actioned over replication. + cache_func=self.store._get_cached_user_device, + key_tuples=keys_to_invalidate, + ) + + self.get_success( + self.store.db_pool.runInteraction( + "test_invalidate_cache_and_stream_bulk", test_txn + ) + ) + + master_invalidate.assert_has_calls( + [call(key_list) for key_list in keys_to_invalidate], + any_order=True, + ) + + +class CacheInvalidationOverReplicationTestCase(BaseMultiWorkerStreamTestCase): + def setUp(self) -> None: + super().setUp() + self.store = self.hs.get_datastores().main + + def test_bulk_invalidation_replicates(self) -> None: + """Like test_bulk_invalidation, but also checks the invalidations replicate.""" + master_invalidate = Mock() + worker_invalidate = Mock() + + self.store._get_cached_user_device.invalidate = master_invalidate + worker = self.make_worker_hs("synapse.app.generic_worker") + worker_ds = worker.get_datastores().main + worker_ds._get_cached_user_device.invalidate = worker_invalidate + + keys_to_invalidate = [ + ("a", "b"), + ("c", "d"), + ("e", "f"), + ("g", "h"), + ] + + def test_txn(txn: LoggingTransaction) -> None: + self.store._invalidate_cache_and_stream_bulk( + txn, + # This is an arbitrarily chosen cached store function. It was chosen + # because it takes more than one argument. We'll use this later to + # check that the invalidation was actioned over replication. + cache_func=self.store._get_cached_user_device, + key_tuples=keys_to_invalidate, + ) + + assert self.store._cache_id_gen is not None + initial_token = self.store._cache_id_gen.get_current_token() + self.get_success( + self.database_pool.runInteraction( + "test_invalidate_cache_and_stream_bulk", test_txn + ) + ) + second_token = self.store._cache_id_gen.get_current_token() + + self.assertGreaterEqual(second_token, initial_token + len(keys_to_invalidate)) + + self.get_success( + worker.get_replication_data_handler().wait_for_stream_position( + "master", "caches", second_token + ) + ) + + master_invalidate.assert_has_calls( + [call(key_list) for key_list in keys_to_invalidate], + any_order=True, + ) + worker_invalidate.assert_has_calls( + [call(key_list) for key_list in keys_to_invalidate], + any_order=True, + ) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index f34b6b2dcf..491e6d5e63 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -222,7 +222,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret) + self.assertEqual((1, 2, 3), ret) self.mock_txn.execute.assert_called_once_with( "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"] ) @@ -243,7 +243,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertFalse(ret) + self.assertIsNone(ret) @defer.inlineCallbacks def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]: diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index ce34195a25..d3ffe963d3 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -42,16 +42,9 @@ class RoomStoreTestCase(HomeserverTestCase): ) def test_get_room(self) -> None: - res = self.get_success(self.store.get_room(self.room.to_string())) - assert res is not None - self.assertLessEqual( - { - "room_id": self.room.to_string(), - "creator": self.u_creator.to_string(), - "is_public": True, - }.items(), - res.items(), - ) + room = self.get_success(self.store.get_room(self.room.to_string())) + assert room is not None + self.assertTrue(room[0]) def test_get_room_unknown_room(self) -> None: self.assertIsNone(self.get_success(self.store.get_room("!uknown:test"))) diff --git a/tests/utils.py b/tests/utils.py index 9be02b8ea7..c44e5cb4ee 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -83,11 +83,11 @@ def setupdb() -> None: # Set up in the db db_conn = db_engine.module.connect( + dbname=POSTGRES_BASE_DB, user=POSTGRES_USER, host=POSTGRES_HOST, port=POSTGRES_PORT, password=POSTGRES_PASSWORD, - dbname=POSTGRES_BASE_DB, ) logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") prepare_database(logging_conn, db_engine, None) |