diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 484db175d0..0553a0621a 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -45,7 +45,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import DeviceListUpdates, JsonDict
+from synapse.types import DeviceListUpdates, JsonMapping
from synapse.util import json_encoder
from synapse.util.caches.descriptors import _CacheContext, cached
@@ -268,8 +268,8 @@ class ApplicationServiceTransactionWorkerStore(
self,
service: ApplicationService,
events: Sequence[EventBase],
- ephemeral: List[JsonDict],
- to_device_messages: List[JsonDict],
+ ephemeral: List[JsonMapping],
+ to_device_messages: List[JsonMapping],
one_time_keys_count: TransactionOneTimeKeysCount,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 70faf4b1ec..df596f35f9 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -55,7 +55,12 @@ from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
StreamIdGenerator,
)
-from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key
+from synapse.types import (
+ JsonDict,
+ JsonMapping,
+ StrCollection,
+ get_verify_key_from_cross_signing_key,
+)
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
@@ -746,7 +751,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
@cancellable
async def get_user_devices_from_cache(
self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
- ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
+ ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonMapping]]]:
"""Get the devices (and keys if any) for remote users from the cache.
Args:
@@ -766,13 +771,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
# First fetch all the users which all devices are to be returned.
- results: Dict[str, Mapping[str, JsonDict]] = {}
+ results: Dict[str, Mapping[str, JsonMapping]] = {}
for user_id in user_ids:
if user_id in user_ids_in_cache:
results[user_id] = await self.get_cached_devices_for_user(user_id)
# Then fetch all device-specific requests, but skip users we've already
# fetched all devices for.
- device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
+ device_specific_results: Dict[str, Dict[str, JsonMapping]] = {}
for user_id, device_id in user_and_device_ids:
if user_id in user_ids_in_cache and user_id not in user_ids:
device = await self._get_cached_user_device(user_id, device_id)
@@ -801,7 +806,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return user_ids_in_cache
@cached(num_args=2, tree=True)
- async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict:
+ async def _get_cached_user_device(
+ self, user_id: str, device_id: str
+ ) -> JsonMapping:
content = await self.db_pool.simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -811,7 +818,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return db_to_json(content)
@cached()
- async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
+ async def get_cached_devices_for_user(
+ self, user_id: str
+ ) -> Mapping[str, JsonMapping]:
devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
@@ -1042,7 +1051,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
async def get_device_list_last_stream_id_for_remotes(
self, user_ids: Iterable[str]
- ) -> Dict[str, Optional[str]]:
+ ) -> Mapping[str, Optional[str]]:
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index b49dea577c..89fac23f93 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -52,7 +52,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.types import JsonDict
+from synapse.types import JsonDict, JsonMapping
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.cancellation import cancellable
@@ -125,7 +125,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
async def get_e2e_device_keys_for_federation_query(
self, user_id: str
- ) -> Tuple[int, List[JsonDict]]:
+ ) -> Tuple[int, Sequence[JsonMapping]]:
"""Get all devices (with any device keys) for a user
Returns:
@@ -174,7 +174,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cached(iterable=True)
async def _get_e2e_device_keys_for_federation_query_inner(
self, user_id: str
- ) -> List[JsonDict]:
+ ) -> Sequence[JsonMapping]:
"""Get all devices (with any device keys) for a user"""
devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
@@ -578,7 +578,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cached(max_entries=10000)
async def count_e2e_one_time_keys(
self, user_id: str, device_id: str
- ) -> Dict[str, int]:
+ ) -> Mapping[str, int]:
"""Count the number of one time keys the server has for a device
Returns:
A mapping from algorithm to number of keys for that algorithm.
@@ -812,7 +812,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
async def get_e2e_cross_signing_key(
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
- ) -> Optional[JsonDict]:
+ ) -> Optional[JsonMapping]:
"""Returns a user's cross-signing key.
Args:
@@ -833,7 +833,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return user_keys.get(key_type)
@cached(num_args=1)
- def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]:
+ def _get_bare_e2e_cross_signing_keys(
+ self, user_id: str
+ ) -> Mapping[str, JsonMapping]:
"""Dummy function. Only used to make a cache for
_get_bare_e2e_cross_signing_keys_bulk.
"""
@@ -846,7 +848,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: Iterable[str]
- ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
+ ) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
@@ -860,15 +862,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
their user ID will map to None.
"""
- result = await self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
)
- # The `Optional` comes from the `@cachedList` decorator.
- return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result)
-
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self,
txn: LoggingTransaction,
@@ -1026,7 +1025,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cancellable
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
- ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
+ ) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]:
"""Returns the cross-signing keys for a set of users.
Args:
@@ -1043,7 +1042,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
if from_user_id:
result = cast(
- Dict[str, Optional[Mapping[str, JsonDict]]],
+ Dict[str, Optional[Mapping[str, JsonMapping]]],
await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_txn,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 1eb313040e..b788d70fc5 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -24,6 +24,7 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
MutableMapping,
Optional,
Set,
@@ -1633,7 +1634,7 @@ class EventsWorkerStore(SQLBaseStore):
self,
room_id: str,
event_ids: Collection[str],
- ) -> Dict[str, bool]:
+ ) -> Mapping[str, bool]:
"""Helper for have_seen_events
Returns:
@@ -2325,7 +2326,7 @@ class EventsWorkerStore(SQLBaseStore):
@cachedList(cached_method_name="is_partial_state_event", list_name="event_ids")
async def get_partial_state_events(
self, event_ids: Collection[str]
- ) -> Dict[str, bool]:
+ ) -> Mapping[str, bool]:
"""Checks which of the given events have partial state
Args:
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 41563371dc..889c578b9c 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,7 +16,7 @@
import itertools
import json
import logging
-from typing import Dict, Iterable, Optional, Tuple
+from typing import Dict, Iterable, Mapping, Optional, Tuple
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
@@ -130,7 +130,7 @@ class KeyStore(CacheInvalidationWorkerStore):
)
async def get_server_keys_json(
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
- ) -> Dict[Tuple[str, str], FetchKeyResult]:
+ ) -> Mapping[Tuple[str, str], FetchKeyResult]:
"""
Args:
server_name_and_key_ids:
@@ -200,7 +200,7 @@ class KeyStore(CacheInvalidationWorkerStore):
)
async def get_server_keys_json_for_remote(
self, server_name: str, key_ids: Iterable[str]
- ) -> Dict[str, Optional[FetchKeyResultForRemote]]:
+ ) -> Mapping[str, Optional[FetchKeyResultForRemote]]:
"""Fetch the cached keys for the given server/key IDs.
If we have multiple entries for a given key ID, returns the most recent.
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index b51d20ac26..194b4e031f 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -11,7 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ cast,
+)
from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
@@ -249,7 +259,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
)
async def get_presence_for_users(
self, user_ids: Iterable[str]
- ) -> Dict[str, UserPresenceState]:
+ ) -> Mapping[str, UserPresenceState]:
rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index bec0dc2afe..af69944008 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -216,7 +216,7 @@ class PushRulesWorkerStore(
@cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
async def bulk_get_push_rules(
self, user_ids: Collection[str]
- ) -> Dict[str, FilteredPushRules]:
+ ) -> Mapping[str, FilteredPushRules]:
if not user_ids:
return {}
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index a074c43989..0231f9407b 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -43,7 +43,7 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator,
StreamIdGenerator,
)
-from synapse.types import JsonDict
+from synapse.types import JsonDict, JsonMapping
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -218,7 +218,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached()
async def _get_receipts_for_user_with_orderings(
self, user_id: str, receipt_type: str
- ) -> JsonDict:
+ ) -> JsonMapping:
"""
Fetch receipts for all rooms that the given user is joined to.
@@ -258,7 +258,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
async def get_linearized_receipts_for_rooms(
self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
- ) -> List[dict]:
+ ) -> List[JsonMapping]:
"""Get receipts for multiple rooms for sending to clients.
Args:
@@ -287,7 +287,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
async def get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
- ) -> Sequence[JsonDict]:
+ ) -> Sequence[JsonMapping]:
"""Get receipts for a single room for sending to clients.
Args:
@@ -310,7 +310,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(tree=True)
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
- ) -> Sequence[JsonDict]:
+ ) -> Sequence[JsonMapping]:
"""See get_linearized_receipts_for_room"""
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
@@ -353,7 +353,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
async def _get_linearized_receipts_for_rooms(
self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
- ) -> Dict[str, Sequence[JsonDict]]:
+ ) -> Mapping[str, Sequence[JsonMapping]]:
if not room_ids:
return {}
@@ -415,7 +415,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
- ) -> Mapping[str, JsonDict]:
+ ) -> Mapping[str, JsonMapping]:
"""Get receipts for all rooms between two stream_ids, up
to a limit of the latest 100 read receipts.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 96908f14ba..6ba9c9651f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -519,7 +519,7 @@ class RelationsWorkerStore(SQLBaseStore):
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
async def get_applicable_edits(
self, event_ids: Collection[str]
- ) -> Dict[str, Optional[EventBase]]:
+ ) -> Mapping[str, Optional[EventBase]]:
"""Get the most recent edit (if any) that has happened for the given
events.
@@ -605,7 +605,7 @@ class RelationsWorkerStore(SQLBaseStore):
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
async def get_thread_summaries(
self, event_ids: Collection[str]
- ) -> Dict[str, Optional[Tuple[int, EventBase]]]:
+ ) -> Mapping[str, Optional[Tuple[int, EventBase]]]:
"""Get the number of threaded replies and the latest reply (if any) for the given events.
Args:
@@ -779,7 +779,7 @@ class RelationsWorkerStore(SQLBaseStore):
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
async def get_threads_participated(
self, event_ids: Collection[str], user_id: str
- ) -> Dict[str, bool]:
+ ) -> Mapping[str, bool]:
"""Get whether the requesting user participated in the given threads.
This is separate from get_thread_summaries since that can be cached across
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index fff259f74c..7b503dd697 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -191,7 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
)
async def get_subset_users_in_room_with_profiles(
self, room_id: str, user_ids: Collection[str]
- ) -> Dict[str, ProfileInfo]:
+ ) -> Mapping[str, ProfileInfo]:
"""Get a mapping from user ID to profile information for a list of users
in a given room.
@@ -676,7 +676,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
)
async def _get_rooms_for_users(
self, user_ids: Collection[str]
- ) -> Dict[str, FrozenSet[str]]:
+ ) -> Mapping[str, FrozenSet[str]]:
"""A batched version of `get_rooms_for_user`.
Returns:
@@ -881,7 +881,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
)
async def _get_user_ids_from_membership_event_ids(
self, event_ids: Iterable[str]
- ) -> Dict[str, Optional[str]]:
+ ) -> Mapping[str, Optional[str]]:
"""For given set of member event_ids check if they point to a join
event.
@@ -1191,7 +1191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
)
async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
- ) -> Dict[str, Optional[EventIdMembership]]:
+ ) -> Mapping[str, Optional[EventIdMembership]]:
"""Get user_id and membership of a set of event IDs.
Returns:
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index ebb2ae964f..5eaaff5b68 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,17 @@
# limitations under the License.
import collections.abc
import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ Iterable,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+)
import attr
@@ -372,7 +382,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
async def _get_state_group_for_events(
self, event_ids: Collection[str]
- ) -> Dict[str, int]:
+ ) -> Mapping[str, int]:
"""Returns mapping event_id -> state_group.
Raises:
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index efd21b5bfc..8f70eff809 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,7 +14,7 @@
import logging
from enum import Enum
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Tuple, cast
import attr
from canonicaljson import encode_canonical_json
@@ -210,7 +210,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
)
async def get_destination_retry_timings_batch(
self, destinations: StrCollection
- ) -> Dict[str, Optional[DestinationRetryTimings]]:
+ ) -> Mapping[str, Optional[DestinationRetryTimings]]:
rows = await self.db_pool.simple_select_many_batch(
table="destinations",
iterable=destinations,
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index f79006533f..06fcbe5e54 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Iterable
+from typing import Iterable, Mapping
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
@@ -40,7 +40,7 @@ class UserErasureWorkerStore(CacheInvalidationWorkerStore):
return bool(result)
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
- async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
+ async def are_users_erased(self, user_ids: Iterable[str]) -> Mapping[str, bool]:
"""
Checks which users in a list have requested erasure
|