author     Patrick Cloke <clokep@users.noreply.github.com>  2023-09-19 15:26:44 -0400
committer  GitHub <noreply@github.com>  2023-09-19 15:26:44 -0400
commit     d7c89c5908f714aa6a142a89da08fafc597ffe0e (patch)
tree       cede0e064ca43482813fa7f6142793c85fb68800
parent     Merge branch 'release-v1.93' into develop (diff)
download   synapse-d7c89c5908f714aa6a142a89da08fafc597ffe0e.tar.xz
Return immutable objects for cachedList decorators (#16350)
-rw-r--r--  changelog.d/16350.misc | 1
-rw-r--r--  synapse/appservice/__init__.py | 6
-rw-r--r--  synapse/appservice/api.py | 6
-rw-r--r--  synapse/appservice/scheduler.py | 18
-rw-r--r--  synapse/handlers/appservice.py | 9
-rw-r--r--  synapse/handlers/e2e_keys.py | 24
-rw-r--r--  synapse/handlers/initial_sync.py | 3
-rw-r--r--  synapse/handlers/receipts.py | 13
-rw-r--r--  synapse/handlers/sync.py | 4
-rw-r--r--  synapse/handlers/typing.py | 17
-rw-r--r--  synapse/push/bulk_push_rule_evaluator.py | 2
-rw-r--r--  synapse/storage/databases/main/appservice.py | 6
-rw-r--r--  synapse/storage/databases/main/devices.py | 23
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py | 25
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 5
-rw-r--r--  synapse/storage/databases/main/keys.py | 6
-rw-r--r--  synapse/storage/databases/main/presence.py | 14
-rw-r--r--  synapse/storage/databases/main/push_rule.py | 2
-rw-r--r--  synapse/storage/databases/main/receipts.py | 14
-rw-r--r--  synapse/storage/databases/main/relations.py | 6
-rw-r--r--  synapse/storage/databases/main/roommember.py | 8
-rw-r--r--  synapse/storage/databases/main/state.py | 14
-rw-r--r--  synapse/storage/databases/main/transactions.py | 4
-rw-r--r--  synapse/storage/databases/main/user_erasure_store.py | 4
24 files changed, 134 insertions, 100 deletions
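
The commit applies one pattern across all of the files below: methods wrapped in the `@cached`/`@cachedList` descriptors, and the handler code that consumes their results, now advertise read-only return types (`Mapping` instead of `Dict`, `Sequence` instead of `List`, and the `JsonMapping` alias instead of `JsonDict`), so callers cannot mutate objects held in a shared cache. A minimal, self-contained sketch of the idea follows, using only the standard library; the `ExampleStore` class, its methods, and its cache attribute are hypothetical and are not taken from Synapse.

from typing import Dict, Iterable, Mapping, Sequence


# Hypothetical sketch (not Synapse code) of the pattern applied throughout this
# commit: values that come out of a shared cache are annotated with read-only
# container types, so mypy rejects call sites that try to mutate them in place.
class ExampleStore:
    def __init__(self) -> None:
        self._cache: Dict[str, bool] = {}

    # Previously a method like this would have been annotated Dict[str, bool];
    # declaring Mapping[str, bool] keeps the shared cached dict safe from callers.
    async def are_widgets_enabled(
        self, user_ids: Iterable[str]
    ) -> Mapping[str, bool]:
        for user_id in user_ids:
            self._cache.setdefault(user_id, True)
        return self._cache

    # Lists of JSON-like payloads get the same treatment: Sequence instead of
    # List, and Mapping-valued elements instead of dicts (cf. JsonMapping).
    async def get_receipt_payloads(self) -> Sequence[Mapping[str, object]]:
        return [{"type": "m.receipt", "content": {}}]

In Synapse itself the sharing comes from the `@cached` and `@cachedList` decorators, which hand the same object back to every caller, so the narrower annotations in the hunks below turn accidental mutation into a type error rather than a cache-corruption bug.
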
diff --git a/changelog.d/16350.misc b/changelog.d/16350.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/16350.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 2260a8f589..6f4aa53c93 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -23,7 +23,7 @@ from netaddr import IPSet
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
-from synapse.types import DeviceListUpdates, JsonDict, UserID
+from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, UserID
 from synapse.util.caches.descriptors import _CacheContext, cached
 
 if TYPE_CHECKING:
@@ -379,8 +379,8 @@ class AppServiceTransaction:
         service: ApplicationService,
         id: int,
         events: Sequence[EventBase],
-        ephemeral: List[JsonDict],
-        to_device_messages: List[JsonDict],
+        ephemeral: List[JsonMapping],
+        to_device_messages: List[JsonMapping],
         one_time_keys_count: TransactionOneTimeKeysCount,
         unused_fallback_keys: TransactionUnusedFallbackKeys,
         device_list_summary: DeviceListUpdates,
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index b1523be208..c42e1f11aa 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -41,7 +41,7 @@ from synapse.events import EventBase
 from synapse.events.utils import SerializeEventConfig, serialize_event
 from synapse.http.client import SimpleHttpClient, is_unknown_endpoint
 from synapse.logging import opentracing
-from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID
+from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, ThirdPartyInstanceID
 from synapse.util.caches.response_cache import ResponseCache
 
 if TYPE_CHECKING:
@@ -306,8 +306,8 @@ class ApplicationServiceApi(SimpleHttpClient):
         self,
         service: "ApplicationService",
         events: Sequence[EventBase],
-        ephemeral: List[JsonDict],
-        to_device_messages: List[JsonDict],
+        ephemeral: List[JsonMapping],
+        to_device_messages: List[JsonMapping],
         one_time_keys_count: TransactionOneTimeKeysCount,
         unused_fallback_keys: TransactionUnusedFallbackKeys,
         device_list_summary: DeviceListUpdates,
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 79f95f7653..18a30bc376 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -73,7 +73,7 @@ from synapse.events import EventBase
 from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.databases.main import DataStore
-from synapse.types import DeviceListUpdates, JsonDict
+from synapse.types import DeviceListUpdates, JsonMapping
 from synapse.util import Clock
 
 if TYPE_CHECKING:
@@ -121,8 +121,8 @@ class ApplicationServiceScheduler:
         self,
         appservice: ApplicationService,
         events: Optional[Collection[EventBase]] = None,
-        ephemeral: Optional[Collection[JsonDict]] = None,
-        to_device_messages: Optional[Collection[JsonDict]] = None,
+        ephemeral: Optional[Collection[JsonMapping]] = None,
+        to_device_messages: Optional[Collection[JsonMapping]] = None,
         device_list_summary: Optional[DeviceListUpdates] = None,
     ) -> None:
         """
@@ -180,9 +180,9 @@ class _ServiceQueuer:
         # dict of {service_id: [events]}
         self.queued_events: Dict[str, List[EventBase]] = {}
         # dict of {service_id: [events]}
-        self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
+        self.queued_ephemeral: Dict[str, List[JsonMapping]] = {}
         # dict of {service_id: [to_device_message_json]}
-        self.queued_to_device_messages: Dict[str, List[JsonDict]] = {}
+        self.queued_to_device_messages: Dict[str, List[JsonMapping]] = {}
         # dict of {service_id: [device_list_summary]}
         self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {}
 
@@ -293,8 +293,8 @@ class _ServiceQueuer:
         self,
         service: ApplicationService,
         events: Iterable[EventBase],
-        ephemerals: Iterable[JsonDict],
-        to_device_messages: Iterable[JsonDict],
+        ephemerals: Iterable[JsonMapping],
+        to_device_messages: Iterable[JsonMapping],
     ) -> Tuple[TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys]:
         """
         Given a list of the events, ephemeral messages and to-device messages,
@@ -364,8 +364,8 @@ class _TransactionController:
         self,
         service: ApplicationService,
         events: Sequence[EventBase],
-        ephemeral: Optional[List[JsonDict]] = None,
-        to_device_messages: Optional[List[JsonDict]] = None,
+        ephemeral: Optional[List[JsonMapping]] = None,
+        to_device_messages: Optional[List[JsonMapping]] = None,
         one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None,
         unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
         device_list_summary: Optional[DeviceListUpdates] = None,
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 6429545c98..7de7bd3289 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -46,6 +46,7 @@ from synapse.storage.databases.main.directory import RoomAliasMapping
 from synapse.types import (
     DeviceListUpdates,
     JsonDict,
+    JsonMapping,
     RoomAlias,
     RoomStreamToken,
     StreamKeyType,
@@ -397,7 +398,7 @@ class ApplicationServicesHandler:
 
     async def _handle_typing(
         self, service: ApplicationService, new_token: int
-    ) -> List[JsonDict]:
+    ) -> List[JsonMapping]:
         """
         Return the typing events since the given stream token that the given application
         service should receive.
@@ -432,7 +433,7 @@ class ApplicationServicesHandler:
 
     async def _handle_receipts(
         self, service: ApplicationService, new_token: int
-    ) -> List[JsonDict]:
+    ) -> List[JsonMapping]:
         """
         Return the latest read receipts that the given application service should receive.
 
@@ -471,7 +472,7 @@ class ApplicationServicesHandler:
         service: ApplicationService,
         users: Collection[Union[str, UserID]],
         new_token: Optional[int],
-    ) -> List[JsonDict]:
+    ) -> List[JsonMapping]:
         """
         Return the latest presence updates that the given application service should receive.
 
@@ -491,7 +492,7 @@ class ApplicationServicesHandler:
             A list of json dictionaries containing data derived from the presence events
             that should be sent to the given application service.
         """
-        events: List[JsonDict] = []
+        events: List[JsonMapping] = []
         presence_source = self.event_sources.sources.presence
         from_key = await self.store.get_type_stream_id_for_appservice(
             service, "presence"
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index ad075497c8..8c6432035d 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
 from synapse.types import (
     JsonDict,
+    JsonMapping,
     UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
@@ -272,11 +273,7 @@ class E2eKeysHandler:
                 delay_cancellation=True,
             )
 
-            ret = {"device_keys": results, "failures": failures}
-
-            ret.update(cross_signing_keys)
-
-            return ret
+            return {"device_keys": results, "failures": failures, **cross_signing_keys}
 
     @trace
     async def _query_devices_for_destination(
@@ -408,7 +405,7 @@ class E2eKeysHandler:
     @cancellable
     async def get_cross_signing_keys_from_cache(
         self, query: Iterable[str], from_user_id: Optional[str]
-    ) -> Dict[str, Dict[str, dict]]:
+    ) -> Dict[str, Dict[str, JsonMapping]]:
         """Get cross-signing keys for users from the database
 
         Args:
@@ -551,16 +548,13 @@ class E2eKeysHandler:
                 self.config.federation.allow_device_name_lookup_over_federation
             ),
         )
-        ret = {"device_keys": res}
 
         # add in the cross-signing keys
         cross_signing_keys = await self.get_cross_signing_keys_from_cache(
             device_keys_query, None
         )
 
-        ret.update(cross_signing_keys)
-
-        return ret
+        return {"device_keys": res, **cross_signing_keys}
 
     async def claim_local_one_time_keys(
         self,
@@ -1127,7 +1121,7 @@ class E2eKeysHandler:
         user_id: str,
         master_key_id: str,
         signed_master_key: JsonDict,
-        stored_master_key: JsonDict,
+        stored_master_key: JsonMapping,
         devices: Dict[str, Dict[str, JsonDict]],
     ) -> List["SignatureListItem"]:
         """Check signatures of a user's master key made by their devices.
@@ -1278,7 +1272,7 @@ class E2eKeysHandler:
 
     async def _get_e2e_cross_signing_verify_key(
         self, user_id: str, key_type: str, from_user_id: Optional[str] = None
-    ) -> Tuple[JsonDict, str, VerifyKey]:
+    ) -> Tuple[JsonMapping, str, VerifyKey]:
         """Fetch locally or remotely query for a cross-signing public key.
 
         First, attempt to fetch the cross-signing public key from storage.
@@ -1333,7 +1327,7 @@ class E2eKeysHandler:
         self,
         user: UserID,
         desired_key_type: str,
-    ) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]:
+    ) -> Optional[Tuple[JsonMapping, str, VerifyKey]]:
         """Queries cross-signing keys for a remote user and saves them to the database
 
         Only the key specified by `key_type` will be returned, while all retrieved keys
@@ -1474,7 +1468,7 @@ def _check_device_signature(
     user_id: str,
     verify_key: VerifyKey,
     signed_device: JsonDict,
-    stored_device: JsonDict,
+    stored_device: JsonMapping,
 ) -> None:
     """Check that a signature on a device or cross-signing key is correct and
     matches the copy of the device/key that we have stored.  Throws an
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 5dc76ef588..5737f8014d 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -32,6 +32,7 @@ from synapse.storage.roommember import RoomsForUser
 from synapse.streams.config import PaginationConfig
 from synapse.types import (
     JsonDict,
+    JsonMapping,
     Requester,
     RoomStreamToken,
     StreamKeyType,
@@ -454,7 +455,7 @@ class InitialSyncHandler:
                 for s in states
             ]
 
-        async def get_receipts() -> List[JsonDict]:
+        async def get_receipts() -> List[JsonMapping]:
             receipts = await self.store.get_linearized_receipts_for_room(
                 room_id, to_key=now_token.receipt_key
             )
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index c7edada353..a7a29b758b 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -19,6 +19,7 @@ from synapse.appservice import ApplicationService
 from synapse.streams import EventSource
 from synapse.types import (
     JsonDict,
+    JsonMapping,
     ReadReceipt,
     StreamKeyType,
     UserID,
@@ -204,15 +205,15 @@ class ReceiptsHandler:
             await self.federation_sender.send_read_receipt(receipt)
 
 
-class ReceiptEventSource(EventSource[int, JsonDict]):
+class ReceiptEventSource(EventSource[int, JsonMapping]):
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self.config = hs.config
 
     @staticmethod
     def filter_out_private_receipts(
-        rooms: Sequence[JsonDict], user_id: str
-    ) -> List[JsonDict]:
+        rooms: Sequence[JsonMapping], user_id: str
+    ) -> List[JsonMapping]:
         """
         Filters a list of serialized receipts (as returned by /sync and /initialSync)
         and removes private read receipts of other users.
@@ -229,7 +230,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
             The same as rooms, but filtered.
         """
 
-        result = []
+        result: List[JsonMapping] = []
 
         # Iterate through each room's receipt content.
         for room in rooms:
@@ -282,7 +283,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
         room_ids: Iterable[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
-    ) -> Tuple[List[JsonDict], int]:
+    ) -> Tuple[List[JsonMapping], int]:
         from_key = int(from_key)
         to_key = self.get_current_key()
 
@@ -301,7 +302,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
 
     async def get_new_events_as(
         self, from_key: int, to_key: int, service: ApplicationService
-    ) -> Tuple[List[JsonDict], int]:
+    ) -> Tuple[List[JsonMapping], int]:
         """Returns a set of new read receipt events that an appservice
         may be interested in.
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 1a4d394eda..7bd42f635f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -235,7 +235,7 @@ class SyncResult:
     archived: List[ArchivedSyncResult]
     to_device: List[JsonDict]
     device_lists: DeviceListUpdates
-    device_one_time_keys_count: JsonDict
+    device_one_time_keys_count: JsonMapping
     device_unused_fallback_key_types: List[str]
 
     def __bool__(self) -> bool:
@@ -1558,7 +1558,7 @@ class SyncHandler:
 
         logger.debug("Fetching OTK data")
         device_id = sync_config.device_id
-        one_time_keys_count: JsonDict = {}
+        one_time_keys_count: JsonMapping = {}
         unused_fallback_key_types: List[str] = []
         if device_id:
             # TODO: We should have a way to let clients differentiate between the states of:
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 4b4227003d..bdefa7f26f 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -26,7 +26,14 @@ from synapse.metrics.background_process_metrics import (
 )
 from synapse.replication.tcp.streams import TypingStream
 from synapse.streams import EventSource
-from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID
+from synapse.types import (
+    JsonDict,
+    JsonMapping,
+    Requester,
+    StrCollection,
+    StreamKeyType,
+    UserID,
+)
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.retryutils import filter_destinations_by_retry_limiter
@@ -487,7 +494,7 @@ class TypingWriterHandler(FollowerTypingHandler):
         raise Exception("Typing writer instance got typing info over replication")
 
 
-class TypingNotificationEventSource(EventSource[int, JsonDict]):
+class TypingNotificationEventSource(EventSource[int, JsonMapping]):
     def __init__(self, hs: "HomeServer"):
         self._main_store = hs.get_datastores().main
         self.clock = hs.get_clock()
@@ -497,7 +504,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
         #
         self.get_typing_handler = hs.get_typing_handler
 
-    def _make_event_for(self, room_id: str) -> JsonDict:
+    def _make_event_for(self, room_id: str) -> JsonMapping:
         typing = self.get_typing_handler()._room_typing[room_id]
         return {
             "type": EduTypes.TYPING,
@@ -507,7 +514,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
 
     async def get_new_events_as(
         self, from_key: int, service: ApplicationService
-    ) -> Tuple[List[JsonDict], int]:
+    ) -> Tuple[List[JsonMapping], int]:
         """Returns a set of new typing events that an appservice
         may be interested in.
 
@@ -551,7 +558,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
         room_ids: Iterable[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
-    ) -> Tuple[List[JsonDict], int]:
+    ) -> Tuple[List[JsonMapping], int]:
         with Measure(self.clock, "typing.get_new_events"):
             from_key = int(from_key)
             handler = self.get_typing_handler()
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 554634579e..14784312dc 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -131,7 +131,7 @@ class BulkPushRuleEvaluator:
     async def _get_rules_for_event(
         self,
         event: EventBase,
-    ) -> Dict[str, FilteredPushRules]:
+    ) -> Mapping[str, FilteredPushRules]:
         """Get the push rules for all users who may need to be notified about
         the event.
 
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 484db175d0..0553a0621a 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -45,7 +45,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import DeviceListUpdates, JsonDict
+from synapse.types import DeviceListUpdates, JsonMapping
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import _CacheContext, cached
 
@@ -268,8 +268,8 @@ class ApplicationServiceTransactionWorkerStore(
         self,
         service: ApplicationService,
         events: Sequence[EventBase],
-        ephemeral: List[JsonDict],
-        to_device_messages: List[JsonDict],
+        ephemeral: List[JsonMapping],
+        to_device_messages: List[JsonMapping],
         one_time_keys_count: TransactionOneTimeKeysCount,
         unused_fallback_keys: TransactionUnusedFallbackKeys,
         device_list_summary: DeviceListUpdates,
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 70faf4b1ec..df596f35f9 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -55,7 +55,12 @@ from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
     StreamIdGenerator,
 )
-from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key
+from synapse.types import (
+    JsonDict,
+    JsonMapping,
+    StrCollection,
+    get_verify_key_from_cross_signing_key,
+)
 from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
@@ -746,7 +751,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     @cancellable
     async def get_user_devices_from_cache(
         self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
-    ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
+    ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonMapping]]]:
         """Get the devices (and keys if any) for remote users from the cache.
 
         Args:
@@ -766,13 +771,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
 
         # First fetch all the users which all devices are to be returned.
-        results: Dict[str, Mapping[str, JsonDict]] = {}
+        results: Dict[str, Mapping[str, JsonMapping]] = {}
         for user_id in user_ids:
             if user_id in user_ids_in_cache:
                 results[user_id] = await self.get_cached_devices_for_user(user_id)
         # Then fetch all device-specific requests, but skip users we've already
         # fetched all devices for.
-        device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
+        device_specific_results: Dict[str, Dict[str, JsonMapping]] = {}
         for user_id, device_id in user_and_device_ids:
             if user_id in user_ids_in_cache and user_id not in user_ids:
                 device = await self._get_cached_user_device(user_id, device_id)
@@ -801,7 +806,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         return user_ids_in_cache
 
     @cached(num_args=2, tree=True)
-    async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict:
+    async def _get_cached_user_device(
+        self, user_id: str, device_id: str
+    ) -> JsonMapping:
         content = await self.db_pool.simple_select_one_onecol(
             table="device_lists_remote_cache",
             keyvalues={"user_id": user_id, "device_id": device_id},
@@ -811,7 +818,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         return db_to_json(content)
 
     @cached()
-    async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
+    async def get_cached_devices_for_user(
+        self, user_id: str
+    ) -> Mapping[str, JsonMapping]:
         devices = await self.db_pool.simple_select_list(
             table="device_lists_remote_cache",
             keyvalues={"user_id": user_id},
@@ -1042,7 +1051,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     )
     async def get_device_list_last_stream_id_for_remotes(
         self, user_ids: Iterable[str]
-    ) -> Dict[str, Optional[str]]:
+    ) -> Mapping[str, Optional[str]]:
         rows = await self.db_pool.simple_select_many_batch(
             table="device_lists_remote_extremeties",
             column="user_id",
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index b49dea577c..89fac23f93 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -52,7 +52,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.types import JsonDict
+from synapse.types import JsonDict, JsonMapping
 from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.cancellation import cancellable
@@ -125,7 +125,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
     async def get_e2e_device_keys_for_federation_query(
         self, user_id: str
-    ) -> Tuple[int, List[JsonDict]]:
+    ) -> Tuple[int, Sequence[JsonMapping]]:
         """Get all devices (with any device keys) for a user
 
         Returns:
@@ -174,7 +174,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     @cached(iterable=True)
     async def _get_e2e_device_keys_for_federation_query_inner(
         self, user_id: str
-    ) -> List[JsonDict]:
+    ) -> Sequence[JsonMapping]:
         """Get all devices (with any device keys) for a user"""
 
         devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
@@ -578,7 +578,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     @cached(max_entries=10000)
     async def count_e2e_one_time_keys(
         self, user_id: str, device_id: str
-    ) -> Dict[str, int]:
+    ) -> Mapping[str, int]:
         """Count the number of one time keys the server has for a device
         Returns:
             A mapping from algorithm to number of keys for that algorithm.
@@ -812,7 +812,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
     async def get_e2e_cross_signing_key(
         self, user_id: str, key_type: str, from_user_id: Optional[str] = None
-    ) -> Optional[JsonDict]:
+    ) -> Optional[JsonMapping]:
         """Returns a user's cross-signing key.
 
         Args:
@@ -833,7 +833,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         return user_keys.get(key_type)
 
     @cached(num_args=1)
-    def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]:
+    def _get_bare_e2e_cross_signing_keys(
+        self, user_id: str
+    ) -> Mapping[str, JsonMapping]:
         """Dummy function.  Only used to make a cache for
         _get_bare_e2e_cross_signing_keys_bulk.
         """
@@ -846,7 +848,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     )
     async def _get_bare_e2e_cross_signing_keys_bulk(
         self, user_ids: Iterable[str]
-    ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
+    ) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
         the signatures for the calling user need to be fetched.
@@ -860,15 +862,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             their user ID will map to None.
 
         """
-        result = await self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_bare_e2e_cross_signing_keys_bulk",
             self._get_bare_e2e_cross_signing_keys_bulk_txn,
             user_ids,
         )
 
-        # The `Optional` comes from the `@cachedList` decorator.
-        return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result)
-
     def _get_bare_e2e_cross_signing_keys_bulk_txn(
         self,
         txn: LoggingTransaction,
@@ -1026,7 +1025,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     @cancellable
     async def get_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str], from_user_id: Optional[str] = None
-    ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
+    ) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]:
         """Returns the cross-signing keys for a set of users.
 
         Args:
@@ -1043,7 +1042,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         if from_user_id:
             result = cast(
-                Dict[str, Optional[Mapping[str, JsonDict]]],
+                Dict[str, Optional[Mapping[str, JsonMapping]]],
                 await self.db_pool.runInteraction(
                     "get_e2e_cross_signing_signatures",
                     self._get_e2e_cross_signing_signatures_txn,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 1eb313040e..b788d70fc5 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -24,6 +24,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     MutableMapping,
     Optional,
     Set,
@@ -1633,7 +1634,7 @@ class EventsWorkerStore(SQLBaseStore):
         self,
         room_id: str,
         event_ids: Collection[str],
-    ) -> Dict[str, bool]:
+    ) -> Mapping[str, bool]:
         """Helper for have_seen_events
 
         Returns:
@@ -2325,7 +2326,7 @@ class EventsWorkerStore(SQLBaseStore):
     @cachedList(cached_method_name="is_partial_state_event", list_name="event_ids")
     async def get_partial_state_events(
         self, event_ids: Collection[str]
-    ) -> Dict[str, bool]:
+    ) -> Mapping[str, bool]:
         """Checks which of the given events have partial state
 
         Args:
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 41563371dc..889c578b9c 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,7 +16,7 @@
 import itertools
 import json
 import logging
-from typing import Dict, Iterable, Optional, Tuple
+from typing import Dict, Iterable, Mapping, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
@@ -130,7 +130,7 @@ class KeyStore(CacheInvalidationWorkerStore):
     )
     async def get_server_keys_json(
         self, server_name_and_key_ids: Iterable[Tuple[str, str]]
-    ) -> Dict[Tuple[str, str], FetchKeyResult]:
+    ) -> Mapping[Tuple[str, str], FetchKeyResult]:
         """
         Args:
             server_name_and_key_ids:
@@ -200,7 +200,7 @@ class KeyStore(CacheInvalidationWorkerStore):
     )
     async def get_server_keys_json_for_remote(
         self, server_name: str, key_ids: Iterable[str]
-    ) -> Dict[str, Optional[FetchKeyResultForRemote]]:
+    ) -> Mapping[str, Optional[FetchKeyResultForRemote]]:
         """Fetch the cached keys for the given server/key IDs.
 
         If we have multiple entries for a given key ID, returns the most recent.
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index b51d20ac26..194b4e031f 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -11,7 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    cast,
+)
 
 from synapse.api.presence import PresenceState, UserPresenceState
 from synapse.replication.tcp.streams import PresenceStream
@@ -249,7 +259,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
     )
     async def get_presence_for_users(
         self, user_ids: Iterable[str]
-    ) -> Dict[str, UserPresenceState]:
+    ) -> Mapping[str, UserPresenceState]:
         rows = await self.db_pool.simple_select_many_batch(
             table="presence_stream",
             column="user_id",
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index bec0dc2afe..af69944008 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -216,7 +216,7 @@ class PushRulesWorkerStore(
     @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
     async def bulk_get_push_rules(
         self, user_ids: Collection[str]
-    ) -> Dict[str, FilteredPushRules]:
+    ) -> Mapping[str, FilteredPushRules]:
         if not user_ids:
             return {}
 
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index a074c43989..0231f9407b 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -43,7 +43,7 @@ from synapse.storage.util.id_generators import (
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
-from synapse.types import JsonDict
+from synapse.types import JsonDict, JsonMapping
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -218,7 +218,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     @cached()
     async def _get_receipts_for_user_with_orderings(
         self, user_id: str, receipt_type: str
-    ) -> JsonDict:
+    ) -> JsonMapping:
         """
         Fetch receipts for all rooms that the given user is joined to.
 
@@ -258,7 +258,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     async def get_linearized_receipts_for_rooms(
         self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
-    ) -> List[dict]:
+    ) -> List[JsonMapping]:
         """Get receipts for multiple rooms for sending to clients.
 
         Args:
@@ -287,7 +287,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     async def get_linearized_receipts_for_room(
         self, room_id: str, to_key: int, from_key: Optional[int] = None
-    ) -> Sequence[JsonDict]:
+    ) -> Sequence[JsonMapping]:
         """Get receipts for a single room for sending to clients.
 
         Args:
@@ -310,7 +310,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     @cached(tree=True)
     async def _get_linearized_receipts_for_room(
         self, room_id: str, to_key: int, from_key: Optional[int] = None
-    ) -> Sequence[JsonDict]:
+    ) -> Sequence[JsonMapping]:
         """See get_linearized_receipts_for_room"""
 
         def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
@@ -353,7 +353,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     )
     async def _get_linearized_receipts_for_rooms(
         self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
-    ) -> Dict[str, Sequence[JsonDict]]:
+    ) -> Mapping[str, Sequence[JsonMapping]]:
         if not room_ids:
             return {}
 
@@ -415,7 +415,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     )
     async def get_linearized_receipts_for_all_rooms(
         self, to_key: int, from_key: Optional[int] = None
-    ) -> Mapping[str, JsonDict]:
+    ) -> Mapping[str, JsonMapping]:
         """Get receipts for all rooms between two stream_ids, up
         to a limit of the latest 100 read receipts.
 
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 96908f14ba..6ba9c9651f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -519,7 +519,7 @@ class RelationsWorkerStore(SQLBaseStore):
     @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
     async def get_applicable_edits(
         self, event_ids: Collection[str]
-    ) -> Dict[str, Optional[EventBase]]:
+    ) -> Mapping[str, Optional[EventBase]]:
         """Get the most recent edit (if any) that has happened for the given
         events.
 
@@ -605,7 +605,7 @@ class RelationsWorkerStore(SQLBaseStore):
     @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
     async def get_thread_summaries(
         self, event_ids: Collection[str]
-    ) -> Dict[str, Optional[Tuple[int, EventBase]]]:
+    ) -> Mapping[str, Optional[Tuple[int, EventBase]]]:
         """Get the number of threaded replies and the latest reply (if any) for the given events.
 
         Args:
@@ -779,7 +779,7 @@ class RelationsWorkerStore(SQLBaseStore):
     @cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
     async def get_threads_participated(
         self, event_ids: Collection[str], user_id: str
-    ) -> Dict[str, bool]:
+    ) -> Mapping[str, bool]:
         """Get whether the requesting user participated in the given threads.
 
         This is separate from get_thread_summaries since that can be cached across
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index fff259f74c..7b503dd697 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -191,7 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
     )
     async def get_subset_users_in_room_with_profiles(
         self, room_id: str, user_ids: Collection[str]
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Mapping[str, ProfileInfo]:
         """Get a mapping from user ID to profile information for a list of users
         in a given room.
 
@@ -676,7 +676,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
     )
     async def _get_rooms_for_users(
         self, user_ids: Collection[str]
-    ) -> Dict[str, FrozenSet[str]]:
+    ) -> Mapping[str, FrozenSet[str]]:
         """A batched version of `get_rooms_for_user`.
 
         Returns:
@@ -881,7 +881,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
     )
     async def _get_user_ids_from_membership_event_ids(
         self, event_ids: Iterable[str]
-    ) -> Dict[str, Optional[str]]:
+    ) -> Mapping[str, Optional[str]]:
         """For given set of member event_ids check if they point to a join
         event.
 
@@ -1191,7 +1191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
     )
     async def get_membership_from_event_ids(
         self, member_event_ids: Iterable[str]
-    ) -> Dict[str, Optional[EventIdMembership]]:
+    ) -> Mapping[str, Optional[EventIdMembership]]:
         """Get user_id and membership of a set of event IDs.
 
         Returns:
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index ebb2ae964f..5eaaff5b68 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,17 @@
 # limitations under the License.
 import collections.abc
 import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    Iterable,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+)
 
 import attr
 
@@ -372,7 +382,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     )
     async def _get_state_group_for_events(
         self, event_ids: Collection[str]
-    ) -> Dict[str, int]:
+    ) -> Mapping[str, int]:
         """Returns mapping event_id -> state_group.
 
         Raises:
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index efd21b5bfc..8f70eff809 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,7 +14,7 @@
 
 import logging
 from enum import Enum
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Tuple, cast
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -210,7 +210,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
     )
     async def get_destination_retry_timings_batch(
         self, destinations: StrCollection
-    ) -> Dict[str, Optional[DestinationRetryTimings]]:
+    ) -> Mapping[str, Optional[DestinationRetryTimings]]:
         rows = await self.db_pool.simple_select_many_batch(
             table="destinations",
             iterable=destinations,
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index f79006533f..06fcbe5e54 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Iterable
+from typing import Iterable, Mapping
 
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main import CacheInvalidationWorkerStore
@@ -40,7 +40,7 @@ class UserErasureWorkerStore(CacheInvalidationWorkerStore):
         return bool(result)
 
     @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
-    async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
+    async def are_users_erased(self, user_ids: Iterable[str]) -> Mapping[str, bool]:
         """
         Checks which users in a list have requested erasure