summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
authorSean Quah <8349537+squahtx@users.noreply.github.com>2023-02-10 23:29:00 +0000
committerGitHub <noreply@github.com>2023-02-10 23:29:00 +0000
commitd0c713cc85f094c323b2ba3f02d8ac411a7f0705 (patch)
treec26f279f53c76ea9adb74853d11cfd264675eeb5 /synapse/storage/databases
parentSupport for MSC3758: exact_event_match push condition (#14964) (diff)
downloadsynapse-d0c713cc85f094c323b2ba3f02d8ac411a7f0705.tar.xz
Return read-only collections from `@cached` methods (#13755)
It's important that collections returned from `@cached` methods are not
modified, otherwise future retrievals from the cache will return the
modified collection.

This applies to the return values from `@cached` methods and the values
inside the dictionaries returned by `@cachedList` methods. It's not
necessary for the dictionaries returned by `@cachedList` methods
themselves to be read-only.

Signed-off-by: Sean Quah <seanq@matrix.org>
Co-authored-by: David Robertson <davidr@element.io>
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/account_data.py2
-rw-r--r--synapse/storage/databases/main/appservice.py2
-rw-r--r--synapse/storage/databases/main/devices.py17
-rw-r--r--synapse/storage/databases/main/directory.py4
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py25
-rw-r--r--synapse/storage/databases/main/event_federation.py11
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py4
-rw-r--r--synapse/storage/databases/main/receipts.py10
-rw-r--r--synapse/storage/databases/main/registration.py4
-rw-r--r--synapse/storage/databases/main/relations.py7
-rw-r--r--synapse/storage/databases/main/roommember.py19
-rw-r--r--synapse/storage/databases/main/signatures.py6
-rw-r--r--synapse/storage/databases/main/tags.py8
-rw-r--r--synapse/storage/databases/main/user_directory.py4
14 files changed, 71 insertions, 52 deletions
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 2d6f02c14f..95567826f2 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -240,7 +240,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
     @cached(num_args=2, tree=True)
     async def get_account_data_for_room(
         self, user_id: str, room_id: str
-    ) -> Dict[str, JsonDict]:
+    ) -> Mapping[str, JsonDict]:
         """Get all the client account_data for a user for a room.
 
         Args:
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 5fb152c4ff..484db175d0 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -166,7 +166,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
         room_id: str,
         app_service: "ApplicationService",
         cache_context: _CacheContext,
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """
         Get all users in a room that the appservice controls.
 
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 85c1778a81..1ca66d57d4 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -21,6 +21,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Set,
     Tuple,
@@ -202,7 +203,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()
 
-    async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
+    async def count_devices_by_users(
+        self, user_ids: Optional[Collection[str]] = None
+    ) -> int:
         """Retrieve number of all devices of given users.
         Only returns number of devices that are not marked as hidden.
 
@@ -213,7 +216,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         """
 
         def count_devices_by_users_txn(
-            txn: LoggingTransaction, user_ids: List[str]
+            txn: LoggingTransaction, user_ids: Collection[str]
         ) -> int:
             sql = """
                 SELECT count(*)
@@ -747,7 +750,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     @cancellable
     async def get_user_devices_from_cache(
         self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
-    ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
+    ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
         """Get the devices (and keys if any) for remote users from the cache.
 
         Args:
@@ -775,16 +778,18 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
 
         # First fetch all the users which all devices are to be returned.
-        results: Dict[str, Dict[str, JsonDict]] = {}
+        results: Dict[str, Mapping[str, JsonDict]] = {}
         for user_id in user_ids:
             if user_id in user_ids_in_cache:
                 results[user_id] = await self.get_cached_devices_for_user(user_id)
         # Then fetch all device-specific requests, but skip users we've already
         # fetched all devices for.
+        device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
         for user_id, device_id in user_and_device_ids:
             if user_id in user_ids_in_cache and user_id not in user_ids:
                 device = await self._get_cached_user_device(user_id, device_id)
-                results.setdefault(user_id, {})[device_id] = device
+                device_specific_results.setdefault(user_id, {})[device_id] = device
+        results.update(device_specific_results)
 
         set_tag("in_cache", str(results))
         set_tag("not_in_cache", str(user_ids_not_in_cache))
@@ -802,7 +807,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         return db_to_json(content)
 
     @cached()
-    async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
+    async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
         devices = await self.db_pool.simple_select_list(
             table="device_lists_remote_cache",
             keyvalues={"user_id": user_id},
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 5903fdaf00..44aa181174 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Sequence, Tuple
 
 import attr
 
@@ -74,7 +74,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore):
         )
 
     @cached(max_entries=5000)
-    async def get_aliases_for_room(self, room_id: str) -> List[str]:
+    async def get_aliases_for_room(self, room_id: str) -> Sequence[str]:
         return await self.db_pool.simple_select_onecol(
             "room_aliases",
             {"room_id": room_id},
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c4ac6c33ba..752dc16e17 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -20,7 +20,9 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
+    Sequence,
     Tuple,
     Union,
     cast,
@@ -691,7 +693,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     @cached(max_entries=10000)
     async def get_e2e_unused_fallback_key_types(
         self, user_id: str, device_id: str
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """Returns the fallback key types that have an unused key.
 
         Args:
@@ -731,7 +733,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         return user_keys.get(key_type)
 
     @cached(num_args=1)
-    def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
+    def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]:
         """Dummy function.  Only used to make a cache for
         _get_bare_e2e_cross_signing_keys_bulk.
         """
@@ -744,7 +746,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     )
     async def _get_bare_e2e_cross_signing_keys_bulk(
         self, user_ids: Iterable[str]
-    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
+    ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
         the signatures for the calling user need to be fetched.
@@ -765,7 +767,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         )
 
         # The `Optional` comes from the `@cachedList` decorator.
-        return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
+        return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result)
 
     def _get_bare_e2e_cross_signing_keys_bulk_txn(
         self,
@@ -924,7 +926,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     @cancellable
     async def get_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str], from_user_id: Optional[str] = None
-    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
+    ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
         """Returns the cross-signing keys for a set of users.
 
         Args:
@@ -940,11 +942,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
 
         if from_user_id:
-            result = await self.db_pool.runInteraction(
-                "get_e2e_cross_signing_signatures",
-                self._get_e2e_cross_signing_signatures_txn,
-                result,
-                from_user_id,
+            result = cast(
+                Dict[str, Optional[Mapping[str, JsonDict]]],
+                await self.db_pool.runInteraction(
+                    "get_e2e_cross_signing_signatures",
+                    self._get_e2e_cross_signing_signatures_txn,
+                    result,
+                    from_user_id,
+                ),
             )
 
         return result
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index bbee02ab18..ca780cca36 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -22,6 +22,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Sequence,
     Set,
     Tuple,
     cast,
@@ -1004,7 +1005,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             room_id,
         )
 
-    async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
+    async def get_max_depth_of(
+        self, event_ids: Collection[str]
+    ) -> Tuple[Optional[str], int]:
         """Returns the event ID and depth for the event that has the max depth from a set of event IDs
 
         Args:
@@ -1141,7 +1144,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     @cached(max_entries=5000, iterable=True)
-    async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
+    async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
         return await self.db_pool.simple_select_onecol(
             table="event_forward_extremities",
             keyvalues={"room_id": room_id},
@@ -1171,7 +1174,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     @cancellable
     async def get_forward_extremities_for_room_at_stream_ordering(
         self, room_id: str, stream_ordering: int
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """For a given room_id and stream_ordering, return the forward
         extremeties of the room at that point in "time".
 
@@ -1204,7 +1207,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     @cached(max_entries=5000, num_args=2)
     async def _get_forward_extremeties_for_room(
         self, room_id: str, stream_ordering: int
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """For a given room_id and stream_ordering, return the forward
         extremeties of the room at that point in "time".
 
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index db9a24db5e..4b1061e6d7 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, cast
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage.database import (
@@ -95,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
         return await self.db_pool.runInteraction("count_users", _count_users)
 
     @cached(num_args=0)
-    async def get_monthly_active_count_by_service(self) -> Dict[str, int]:
+    async def get_monthly_active_count_by_service(self) -> Mapping[str, int]:
         """Generates current count of monthly active users broken down by service.
         A service is typically an appservice but also includes native matrix users.
         Since the `monthly_active_users` table is populated from the `user_ips` table
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 29972d5204..dddf49c2d5 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -21,7 +21,9 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
+    Sequence,
     Tuple,
     cast,
 )
@@ -288,7 +290,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     async def get_linearized_receipts_for_room(
         self, room_id: str, to_key: int, from_key: Optional[int] = None
-    ) -> List[dict]:
+    ) -> Sequence[JsonDict]:
         """Get receipts for a single room for sending to clients.
 
         Args:
@@ -311,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     @cached(tree=True)
     async def _get_linearized_receipts_for_room(
         self, room_id: str, to_key: int, from_key: Optional[int] = None
-    ) -> List[JsonDict]:
+    ) -> Sequence[JsonDict]:
         """See get_linearized_receipts_for_room"""
 
         def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
@@ -354,7 +356,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     )
     async def _get_linearized_receipts_for_rooms(
         self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
-    ) -> Dict[str, List[JsonDict]]:
+    ) -> Dict[str, Sequence[JsonDict]]:
         if not room_ids:
             return {}
 
@@ -416,7 +418,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     )
     async def get_linearized_receipts_for_all_rooms(
         self, to_key: int, from_key: Optional[int] = None
-    ) -> Dict[str, JsonDict]:
+    ) -> Mapping[str, JsonDict]:
         """Get receipts for all rooms between two stream_ids, up
         to a limit of the latest 100 read receipts.
 
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 31f0f2bd3d..9a55e17624 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
 import logging
 import random
 import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
 
 import attr
 
@@ -192,7 +192,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             )
 
     @cached()
-    async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+    async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
         """Deprecated: use get_userinfo_by_id instead"""
 
         def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 0018d6f7ab..fa3266c081 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -22,6 +22,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Sequence,
     Set,
     Tuple,
     Union,
@@ -171,7 +172,7 @@ class RelationsWorkerStore(SQLBaseStore):
         direction: Direction = Direction.BACKWARDS,
         from_token: Optional[StreamToken] = None,
         to_token: Optional[StreamToken] = None,
-    ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
+    ) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
         """Get a list of relations for an event, ordered by topological ordering.
 
         Args:
@@ -397,7 +398,9 @@ class RelationsWorkerStore(SQLBaseStore):
         return result is not None
 
     @cached()
-    async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]:
+    async def get_aggregation_groups_for_event(
+        self, event_id: str
+    ) -> Sequence[JsonDict]:
         raise NotImplementedError()
 
     @cachedList(
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index ea6a5e2f34..694a5b802c 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -24,6 +24,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Sequence,
     Set,
     Tuple,
     Union,
@@ -153,7 +154,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return self._known_servers_count
 
     @cached(max_entries=100000, iterable=True)
-    async def get_users_in_room(self, room_id: str) -> List[str]:
+    async def get_users_in_room(self, room_id: str) -> Sequence[str]:
         """Returns a list of users in the room.
 
         Will return inaccurate results for rooms with partial state, since the state for
@@ -190,9 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     @cached()
-    def get_user_in_room_with_profile(
-        self, room_id: str, user_id: str
-    ) -> Dict[str, ProfileInfo]:
+    def get_user_in_room_with_profile(self, room_id: str, user_id: str) -> ProfileInfo:
         raise NotImplementedError()
 
     @cachedList(
@@ -246,7 +245,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     @cached(max_entries=100000, iterable=True)
     async def get_users_in_room_with_profiles(
         self, room_id: str
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Mapping[str, ProfileInfo]:
         """Get a mapping from user ID to profile information for all users in a given room.
 
         The profile information comes directly from this room's `m.room.member`
@@ -285,7 +284,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     @cached(max_entries=100000)
-    async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
+    async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]:
         """Get the details of a room roughly suitable for use by the room
         summary extension to /sync. Useful when lazy loading room members.
         Args:
@@ -357,7 +356,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     @cached()
     async def get_invited_rooms_for_local_user(
         self, user_id: str
-    ) -> List[RoomsForUser]:
+    ) -> Sequence[RoomsForUser]:
         """Get all the rooms the *local* user is invited to.
 
         Args:
@@ -475,7 +474,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return results
 
     @cached(iterable=True)
-    async def get_local_users_in_room(self, room_id: str) -> List[str]:
+    async def get_local_users_in_room(self, room_id: str) -> Sequence[str]:
         """
         Retrieves a list of the current roommembers who are local to the server.
         """
@@ -791,7 +790,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         """Returns the set of users who share a room with `user_id`"""
         room_ids = await self.get_rooms_for_user(user_id)
 
-        user_who_share_room = set()
+        user_who_share_room: Set[str] = set()
         for room_id in room_ids:
             user_ids = await self.get_users_in_room(room_id)
             user_who_share_room.update(user_ids)
@@ -953,7 +952,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return True
 
     @cached(iterable=True, max_entries=10000)
-    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+    async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
         """Get current hosts in room based on current state."""
 
         # First we check if we already have `get_users_in_room` in the cache, as
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 05da15074a..5dcb1fc0b5 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Collection, Dict, List, Tuple
+from typing import Collection, Dict, List, Mapping, Tuple
 
 from unpaddedbase64 import encode_base64
 
@@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList
 
 class SignatureWorkerStore(EventsWorkerStore):
     @cached()
-    def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
+    def get_event_reference_hash(self, event_id: str) -> Mapping[str, bytes]:
         # This is a dummy function to allow get_event_reference_hashes
         # to use its cache
         raise NotImplementedError()
@@ -36,7 +36,7 @@ class SignatureWorkerStore(EventsWorkerStore):
     )
     async def get_event_reference_hashes(
         self, event_ids: Collection[str]
-    ) -> Dict[str, Dict[str, bytes]]:
+    ) -> Mapping[str, Mapping[str, bytes]]:
         """Get all hashes for given events.
 
         Args:
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index d5500cdd47..c149a9eacb 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, Iterable, List, Tuple, cast
+from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast
 
 from synapse.api.constants import AccountDataTypes
 from synapse.replication.tcp.streams import AccountDataStream
@@ -32,7 +32,9 @@ logger = logging.getLogger(__name__)
 
 class TagsWorkerStore(AccountDataWorkerStore):
     @cached()
-    async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
+    async def get_tags_for_user(
+        self, user_id: str
+    ) -> Mapping[str, Mapping[str, JsonDict]]:
         """Get all the tags for a user.
 
 
@@ -107,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
 
     async def get_updated_tags(
         self, user_id: str, stream_id: int
-    ) -> Dict[str, Dict[str, JsonDict]]:
+    ) -> Mapping[str, Mapping[str, JsonDict]]:
         """Get all the tags for the rooms where the tags have changed since the
         given version
 
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 14ef5b040d..f6a6fd4079 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -16,9 +16,9 @@ import logging
 import re
 from typing import (
     TYPE_CHECKING,
-    Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Sequence,
     Set,
@@ -586,7 +586,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         )
 
     @cached()
-    async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
+    async def get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
         return await self.db_pool.simple_select_one(
             table="user_directory",
             keyvalues={"user_id": user_id},