diff options
author | Patrick Cloke <clokep@users.noreply.github.com> | 2023-11-09 11:13:31 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-09 11:13:31 -0500 |
commit | ab3f1b3b53cdb4e2751b3e8c5eb052d7475be58f (patch) | |
tree | d8f032a9833f5998626f67c9103768d6b77b4945 /synapse/storage | |
parent | Return attrs for more media repo APIs. (#16611) (diff) | |
download | synapse-ab3f1b3b53cdb4e2751b3e8c5eb052d7475be58f.tar.xz |
Convert simple_select_one_txn and simple_select_one to return tuples. (#16612)
Diffstat (limited to 'synapse/storage')
19 files changed, 248 insertions, 241 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0af0507307..eb34de4df5 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1597,7 +1597,7 @@ class DatabasePool: retcols: Collection[str], allow_none: Literal[False] = False, desc: str = "simple_select_one", - ) -> Dict[str, Any]: + ) -> Tuple[Any, ...]: ... @overload @@ -1608,7 +1608,7 @@ class DatabasePool: retcols: Collection[str], allow_none: Literal[True] = True, desc: str = "simple_select_one", - ) -> Optional[Dict[str, Any]]: + ) -> Optional[Tuple[Any, ...]]: ... async def simple_select_one( @@ -1618,7 +1618,7 @@ class DatabasePool: retcols: Collection[str], allow_none: bool = False, desc: str = "simple_select_one", - ) -> Optional[Dict[str, Any]]: + ) -> Optional[Tuple[Any, ...]]: """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. @@ -2127,7 +2127,7 @@ class DatabasePool: keyvalues: Dict[str, Any], retcols: Collection[str], allow_none: bool = False, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[Tuple[Any, ...]]: select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table) if keyvalues: @@ -2145,7 +2145,7 @@ class DatabasePool: if txn.rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) - return dict(zip(retcols, row)) + return row async def simple_delete_one( self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one" diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 04d12a876c..775abbac79 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -255,33 +255,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): A dict containing the device information, or `None` if the device does not exist. """ - return await self.db_pool.simple_select_one( - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, - retcols=("user_id", "device_id", "display_name"), - desc="get_device", - allow_none=True, - ) - - async def get_device_opt( - self, user_id: str, device_id: str - ) -> Optional[Dict[str, Any]]: - """Retrieve a device. Only returns devices that are not marked as - hidden. - - Args: - user_id: The ID of the user which owns the device - device_id: The ID of the device to retrieve - Returns: - A dict containing the device information, or None if the device does not exist. - """ - return await self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), desc="get_device", allow_none=True, ) + if row is None: + return None + return {"user_id": row[0], "device_id": row[1], "display_name": row[2]} async def get_devices_by_user( self, user_id: str @@ -1221,9 +1204,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): retcols=["device_id", "device_data"], allow_none=True, ) - return ( - (row["device_id"], json_decoder.decode(row["device_data"])) if row else None - ) + return (row[0], json_decoder.decode(row[1])) if row else None def _store_dehydrated_device_txn( self, @@ -2326,13 +2307,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): `FALSE` have not been converted. """ - row = await self.db_pool.simple_select_one( - table="device_lists_changes_converted_stream_position", - keyvalues={}, - retcols=["stream_id", "room_id"], - desc="get_device_change_last_converted_pos", + return cast( + Tuple[int, str], + await self.db_pool.simple_select_one( + table="device_lists_changes_converted_stream_position", + keyvalues={}, + retcols=["stream_id", "room_id"], + desc="get_device_change_last_converted_pos", + ), ) - return row["stream_id"], row["room_id"] async def set_device_change_last_converted_pos( self, diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index ad904a26a6..fae23c3407 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -506,19 +506,26 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): # it isn't there. raise StoreError(404, "No backup with that version exists") - result = self.db_pool.simple_select_one_txn( - txn, - table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, - retcols=("version", "algorithm", "auth_data", "etag"), - allow_none=False, + row = cast( + Tuple[int, str, str, Optional[int]], + self.db_pool.simple_select_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={ + "user_id": user_id, + "version": this_version, + "deleted": 0, + }, + retcols=("version", "algorithm", "auth_data", "etag"), + allow_none=False, + ), ) - assert result is not None # see comment on `simple_select_one_txn` - result["auth_data"] = db_to_json(result["auth_data"]) - result["version"] = str(result["version"]) - if result["etag"] is None: - result["etag"] = 0 - return result + return { + "auth_data": db_to_json(row[2]), + "version": str(row[0]), + "algorithm": row[1], + "etag": 0 if row[3] is None else row[3], + } return await self.db_pool.runInteraction( "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 3005e2a2c5..8cb61eaee3 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1266,9 +1266,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker if row is None: continue - key_id = row["key_id"] - key_json = row["key_json"] - used = row["used"] + key_id, key_json, used = row # Mark fallback key as used if not already. if not used and mark_as_used: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index f1b0991503..7e992ca4a2 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -193,7 +193,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Check if we have indexed the room so we can use the chain cover # algorithm. room = await self.get_room(room_id) # type: ignore[attr-defined] - if room["has_auth_chain_index"]: + # If the room has an auth chain index. + if room[1]: try: return await self.db_pool.runInteraction( "get_auth_chain_ids_chains", @@ -411,7 +412,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Check if we have indexed the room so we can use the chain cover # algorithm. room = await self.get_room(room_id) # type: ignore[attr-defined] - if room["has_auth_chain_index"]: + # If the room has an auth chain index. + if room[1]: try: return await self.db_pool.runInteraction( "get_auth_chain_difference_chains", @@ -1437,24 +1439,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) if event_lookup_result is not None: + event_type, depth, stream_ordering = event_lookup_result logger.debug( "_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s", room_id, seed_event_id, - event_lookup_result["depth"], - event_lookup_result["stream_ordering"], - event_lookup_result["type"], + depth, + stream_ordering, + event_type, ) - if event_lookup_result["depth"]: - queue.put( - ( - -event_lookup_result["depth"], - -event_lookup_result["stream_ordering"], - seed_event_id, - event_lookup_result["type"], - ) - ) + if depth: + queue.put((-depth, -stream_ordering, seed_event_id, event_type)) while not queue.empty() and len(event_id_results) < limit: try: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 7c34bde3e5..5207cc0f4e 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1934,8 +1934,7 @@ class PersistEventsStore: if row is None: return - redacted_relates_to = row["relates_to_id"] - rel_type = row["relation_type"] + redacted_relates_to, rel_type = row self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 5bf864c1fb..4e63a16fa2 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1998,7 +1998,7 @@ class EventsWorkerStore(SQLBaseStore): if not res: raise SynapseError(404, "Could not find event %s" % (event_id,)) - return int(res["topological_ordering"]), int(res["stream_ordering"]) + return int(res[0]), int(res[1]) async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]: """Retrieve the entry with the lowest expiry timestamp in the event_expiry diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 7f99c64f1b..3f80a64dc5 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -208,7 +208,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) if row is None: return None - return LocalMedia(media_id=media_id, **row) + return LocalMedia( + media_id=media_id, + media_type=row[0], + media_length=row[1], + upload_name=row[2], + created_ts=row[3], + quarantined_by=row[4], + url_cache=row[5], + last_access_ts=row[6], + safe_from_quarantine=row[7], + ) async def get_local_media_by_user_paginate( self, @@ -541,7 +551,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) if row is None: return row - return RemoteMedia(media_origin=origin, media_id=media_id, **row) + return RemoteMedia( + media_origin=origin, + media_id=media_id, + media_type=row[0], + media_length=row[1], + upload_name=row[2], + created_ts=row[3], + filesystem_id=row[4], + last_access_ts=row[5], + quarantined_by=row[6], + ) async def store_cached_remote_media( self, @@ -665,11 +685,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): if row is None: return None return ThumbnailInfo( - width=row["thumbnail_width"], - height=row["thumbnail_height"], - method=row["thumbnail_method"], - type=row["thumbnail_type"], - length=row["thumbnail_length"], + width=row[0], height=row[1], method=row[2], type=row[3], length=row[4] ) @trace diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 3ba9cc8853..7ed111f632 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import TYPE_CHECKING, Optional -from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -138,23 +137,18 @@ class ProfileWorkerStore(SQLBaseStore): return 50 async def get_profileinfo(self, user_id: UserID) -> ProfileInfo: - try: - profile = await self.db_pool.simple_select_one( - table="profiles", - keyvalues={"full_user_id": user_id.to_string()}, - retcols=("displayname", "avatar_url"), - desc="get_profileinfo", - ) - except StoreError as e: - if e.code == 404: - # no match - return ProfileInfo(None, None) - else: - raise - - return ProfileInfo( - avatar_url=profile["avatar_url"], display_name=profile["displayname"] + profile = await self.db_pool.simple_select_one( + table="profiles", + keyvalues={"full_user_id": user_id.to_string()}, + retcols=("displayname", "avatar_url"), + desc="get_profileinfo", + allow_none=True, ) + if profile is None: + # no match + return ProfileInfo(None, None) + + return ProfileInfo(avatar_url=profile[1], display_name=profile[0]) async def get_profile_displayname(self, user_id: UserID) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 37135d431d..f72a23c584 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -468,8 +468,7 @@ class PushRuleStore(PushRulesWorkerStore): "before/after rule not found: %s" % (relative_to_rule,) ) - base_priority_class = res["priority_class"] - base_rule_priority = res["priority"] + base_priority_class, base_rule_priority = res if base_priority_class != priority_class: raise InconsistentRuleException( diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 56e8eb16a8..3484ce9ef9 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -701,8 +701,8 @@ class ReceiptsWorkerStore(SQLBaseStore): allow_none=True, ) - stream_ordering = int(res["stream_ordering"]) if res else None - rx_ts = res["received_ts"] if res else 0 + stream_ordering = int(res[0]) if res else None + rx_ts = res[1] if res else 0 # We don't want to clobber receipts for more recent events, so we # have to compare orderings of existing receipts diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 933d76e905..dec9858575 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): account timestamp as milliseconds since the epoch. None if the account has not been renewed using the current token yet. """ - ret_dict = await self.db_pool.simple_select_one( - table="account_validity", - keyvalues={"renewal_token": renewal_token}, - retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], - desc="get_user_from_renewal_token", - ) - - return ( - ret_dict["user_id"], - ret_dict["expiration_ts_ms"], - ret_dict["token_used_ts_ms"], + return cast( + Tuple[str, int, Optional[int]], + await self.db_pool.simple_select_one( + table="account_validity", + keyvalues={"renewal_token": renewal_token}, + retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], + desc="get_user_from_renewal_token", + ), ) async def get_renewal_token_for_user(self, user_id: str) -> str: @@ -989,16 +986,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Returns: user id, or None if no user id/threepid mapping exists """ - ret = self.db_pool.simple_select_one_txn( + return self.db_pool.simple_select_one_onecol_txn( txn, "user_threepids", {"medium": medium, "address": address}, - ["user_id"], + "user_id", True, ) - if ret: - return ret["user_id"] - return None async def user_add_threepid( self, @@ -1435,16 +1429,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): if res is None: return False + uses_allowed, pending, completed, expiry_time = res + # Check if the token has expired now = self._clock.time_msec() - if res["expiry_time"] and res["expiry_time"] < now: + if expiry_time and expiry_time < now: return False # Check if the token has been used up - if ( - res["uses_allowed"] - and res["pending"] + res["completed"] >= res["uses_allowed"] - ): + if uses_allowed and pending + completed >= uses_allowed: return False # Otherwise, the token is valid @@ -1490,8 +1483,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # Override type because the return type is only optional if # allow_none is True, and we don't want mypy throwing errors # about None not being indexable. - res = cast( - Dict[str, Any], + pending, completed = cast( + Tuple[int, int], self.db_pool.simple_select_one_txn( txn, "registration_tokens", @@ -1506,8 +1499,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "registration_tokens", keyvalues={"token": token}, updatevalues={ - "completed": res["completed"] + 1, - "pending": res["pending"] - 1, + "completed": completed + 1, + "pending": pending - 1, }, ) @@ -1585,13 +1578,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Returns: A dict, or None if token doesn't exist. """ - return await self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( "registration_tokens", keyvalues={"token": token}, retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"], allow_none=True, desc="get_one_registration_token", ) + if row is None: + return None + return { + "token": row[0], + "uses_allowed": row[1], + "pending": row[2], + "completed": row[3], + "expiry_time": row[4], + } async def generate_registration_token( self, length: int, chars: str @@ -1714,7 +1716,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return None # Get all info about the token so it can be sent in the response - return self.db_pool.simple_select_one_txn( + result = self.db_pool.simple_select_one_txn( txn, "registration_tokens", keyvalues={"token": token}, @@ -1728,6 +1730,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): allow_none=True, ) + if result is None: + return result + + return { + "token": result[0], + "uses_allowed": result[1], + "pending": result[2], + "completed": result[3], + "expiry_time": result[4], + } + return await self.db_pool.runInteraction( "update_registration_token", _update_registration_token_txn ) @@ -1939,11 +1952,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): keyvalues={"token": token}, updatevalues={"used_ts": ts}, ) - user_id = values["user_id"] - expiry_ts = values["expiry_ts"] - used_ts = values["used_ts"] - auth_provider_id = values["auth_provider_id"] - auth_provider_session_id = values["auth_provider_session_id"] + ( + user_id, + expiry_ts, + used_ts, + auth_provider_id, + auth_provider_session_id, + ) = values # Token was already used if used_ts is not None: @@ -2756,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): # reason, the next check is on the client secret, which is NOT NULL, # so we don't have to worry about the client secret matching by # accident. - row = {"client_secret": None, "validated_at": None} + row = None, None else: raise ThreepidValidationError("Unknown session_id") - retrieved_client_secret = row["client_secret"] - validated_at = row["validated_at"] + retrieved_client_secret, validated_at = row row = self.db_pool.simple_select_one_txn( txn, @@ -2775,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): raise ThreepidValidationError( "Validation token not found or has expired" ) - expires = row["expires"] - next_link = row["next_link"] + expires, next_link = row if retrieved_client_secret != client_secret: raise ThreepidValidationError( diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index afb880532e..ef26d5d9d3 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -213,21 +213,31 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") - async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]: + async def get_room(self, room_id: str) -> Optional[Tuple[bool, bool]]: """Retrieve a room. Args: room_id: The ID of the room to retrieve. Returns: - A dict containing the room information, or None if the room is unknown. + A tuple containing the room information: + * True if the room is public + * True if the room has an auth chain index + + or None if the room is unknown. """ - return await self.db_pool.simple_select_one( - table="rooms", - keyvalues={"room_id": room_id}, - retcols=("room_id", "is_public", "creator", "has_auth_chain_index"), - desc="get_room", - allow_none=True, + row = cast( + Optional[Tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]], + await self.db_pool.simple_select_one( + table="rooms", + keyvalues={"room_id": room_id}, + retcols=("is_public", "has_auth_chain_index"), + desc="get_room", + allow_none=True, + ), ) + if row is None: + return row + return bool(row[0]), bool(row[1]) async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]: """Retrieve room with statistics. @@ -794,10 +804,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): ) if row: - return RatelimitOverride( - messages_per_second=row["messages_per_second"], - burst_count=row["burst_count"], - ) + return RatelimitOverride(messages_per_second=row[0], burst_count=row[1]) else: return None @@ -1371,13 +1378,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): join. """ - result = await self.db_pool.simple_select_one( - table="partial_state_rooms", - keyvalues={"room_id": room_id}, - retcols=("join_event_id", "device_lists_stream_id"), - desc="get_join_event_id_for_partial_state", + return cast( + Tuple[str, int], + await self.db_pool.simple_select_one( + table="partial_state_rooms", + keyvalues={"room_id": room_id}, + retcols=("join_event_id", "device_lists_stream_id"), + desc="get_join_event_id_for_partial_state", + ), ) - return result["join_event_id"], result["device_lists_stream_id"] def get_un_partial_stated_rooms_token(self, instance_name: str) -> int: return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 1ed7f2d0ef..60d4a9ef30 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -559,17 +559,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): "non-local user %s" % (user_id,), ) - results_dict = await self.db_pool.simple_select_one( - "local_current_membership", - {"room_id": room_id, "user_id": user_id}, - ("membership", "event_id"), - allow_none=True, - desc="get_local_current_membership_for_user_in_room", + results = cast( + Optional[Tuple[str, str]], + await self.db_pool.simple_select_one( + "local_current_membership", + {"room_id": room_id, "user_id": user_id}, + ("membership", "event_id"), + allow_none=True, + desc="get_local_current_membership_for_user_in_room", + ), ) - if not results_dict: + if not results: return None, None - return results_dict.get("membership"), results_dict.get("event_id") + return results @cached(max_entries=500000, iterable=True) async def get_rooms_for_user_with_stream_ordering( diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 2225f8272d..563c275a2c 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1014,9 +1014,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): desc="get_position_for_event", ) - return PersistedEventPosition( - row["instance_name"] or "master", row["stream_ordering"] - ) + return PersistedEventPosition(row[1] or "master", row[0]) async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken: """The stream token for an event @@ -1033,9 +1031,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", ) - return RoomStreamToken( - topological=row["topological_ordering"], stream=row["stream_ordering"] - ) + return RoomStreamToken(topological=row[1], stream=row[0]) async def get_current_topological_token(self, room_id: str, stream_key: int) -> int: """Gets the topological token in a room after or at the given stream @@ -1180,26 +1176,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - results = self.db_pool.simple_select_one_txn( - txn, - "events", - keyvalues={"event_id": event_id, "room_id": room_id}, - retcols=["stream_ordering", "topological_ordering"], + stream_ordering, topological_ordering = cast( + Tuple[int, int], + self.db_pool.simple_select_one_txn( + txn, + "events", + keyvalues={"event_id": event_id, "room_id": room_id}, + retcols=["stream_ordering", "topological_ordering"], + ), ) - # This cannot happen as `allow_none=False`. - assert results is not None - # Paginating backwards includes the event at the token, but paginating # forward doesn't. before_token = RoomStreamToken( - topological=results["topological_ordering"] - 1, - stream=results["stream_ordering"], + topological=topological_ordering - 1, stream=stream_ordering ) after_token = RoomStreamToken( - topological=results["topological_ordering"], - stream=results["stream_ordering"], + topological=topological_ordering, stream=stream_ordering ) rows, start_token = self._paginate_room_events_txn( diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py index 5555b53575..64543b4d61 100644 --- a/synapse/storage/databases/main/task_scheduler.py +++ b/synapse/storage/databases/main/task_scheduler.py @@ -183,39 +183,27 @@ class TaskSchedulerWorkerStore(SQLBaseStore): Returns: the task if available, `None` otherwise """ - row = await self.db_pool.simple_select_one( - table="scheduled_tasks", - keyvalues={"id": id}, - retcols=( - "id", - "action", - "status", - "timestamp", - "resource_id", - "params", - "result", - "error", + row = cast( + Optional[ScheduledTaskRow], + await self.db_pool.simple_select_one( + table="scheduled_tasks", + keyvalues={"id": id}, + retcols=( + "id", + "action", + "status", + "timestamp", + "resource_id", + "params", + "result", + "error", + ), + allow_none=True, + desc="get_scheduled_task", ), - allow_none=True, - desc="get_scheduled_task", ) - return ( - TaskSchedulerWorkerStore._convert_row_to_task( - ( - row["id"], - row["action"], - row["status"], - row["timestamp"], - row["resource_id"], - row["params"], - row["result"], - row["error"], - ) - ) - if row - else None - ) + return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None async def delete_scheduled_task(self, id: str) -> None: """Delete a specific task from its id. diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index fecddb4144..2d341affaa 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -118,19 +118,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): txn, table="received_transactions", keyvalues={"transaction_id": transaction_id, "origin": origin}, - retcols=( - "transaction_id", - "origin", - "ts", - "response_code", - "response_json", - "has_been_referenced", - ), + retcols=("response_code", "response_json"), allow_none=True, ) - if result and result["response_code"]: - return result["response_code"], db_to_json(result["response_json"]) + # If the result exists and the response code is non-0. + if result and result[0]: + return result[0], db_to_json(result[1]) else: return None @@ -200,8 +194,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): # check we have a row and retry_last_ts is not null or zero # (retry_last_ts can't be negative) - if result and result["retry_last_ts"]: - return DestinationRetryTimings(**result) + if result and result[1]: + return DestinationRetryTimings( + failure_ts=result[0], retry_last_ts=result[1], retry_interval=result[2] + ) else: return None diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 8ab7c42c4a..5b164fed8e 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -122,9 +122,13 @@ class UIAuthWorkerStore(SQLBaseStore): desc="get_ui_auth_session", ) - result["clientdict"] = db_to_json(result["clientdict"]) - - return UIAuthSessionData(session_id, **result) + return UIAuthSessionData( + session_id, + clientdict=db_to_json(result[0]), + uri=result[1], + method=result[2], + description=result[3], + ) async def mark_ui_auth_stage_complete( self, @@ -231,18 +235,15 @@ class UIAuthWorkerStore(SQLBaseStore): self, txn: LoggingTransaction, session_id: str, key: str, value: Any ) -> None: # Get the current value. - result = cast( - Dict[str, Any], - self.db_pool.simple_select_one_txn( - txn, - table="ui_auth_sessions", - keyvalues={"session_id": session_id}, - retcols=("serverdict",), - ), + result = self.db_pool.simple_select_one_onecol_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcol="serverdict", ) # Update it and add it back to the database. - serverdict = db_to_json(result["serverdict"]) + serverdict = db_to_json(result) serverdict[key] = value self.db_pool.simple_update_one_txn( @@ -265,14 +266,14 @@ class UIAuthWorkerStore(SQLBaseStore): Raises: StoreError if the session cannot be found. """ - result = await self.db_pool.simple_select_one( + result = await self.db_pool.simple_select_one_onecol( table="ui_auth_sessions", keyvalues={"session_id": session_id}, - retcols=("serverdict",), + retcol="serverdict", desc="get_ui_auth_session_data", ) - serverdict = db_to_json(result["serverdict"]) + serverdict = db_to_json(result) return serverdict.get(key, default) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index a9f5d68b63..1a38f3d785 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -20,7 +20,6 @@ from typing import ( Collection, Iterable, List, - Mapping, Optional, Sequence, Set, @@ -833,13 +832,25 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): "delete_all_from_user_dir", _delete_all_from_user_dir_txn ) - async def _get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]: - return await self.db_pool.simple_select_one( - table="user_directory", - keyvalues={"user_id": user_id}, - retcols=("display_name", "avatar_url"), - allow_none=True, - desc="get_user_in_directory", + async def _get_user_in_directory( + self, user_id: str + ) -> Optional[Tuple[Optional[str], Optional[str]]]: + """ + Fetch the user information in the user directory. + + Returns: + None if the user is unknown, otherwise a tuple of display name and + avatar URL (both of which may be None). + """ + return cast( + Optional[Tuple[Optional[str], Optional[str]]], + await self.db_pool.simple_select_one( + table="user_directory", + keyvalues={"user_id": user_id}, + retcols=("display_name", "avatar_url"), + allow_none=True, + desc="get_user_in_directory", + ), ) async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None: |