summary refs log tree commit diff
path: root/synapse/storage/databases/main/devices.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/storage/databases/main/devices.py27
1 files changed, 16 insertions, 11 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py

index 53024bddc3..6191f22cd6 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -27,6 +27,7 @@ from typing import ( Dict, Iterable, List, + Literal, Mapping, Optional, Set, @@ -35,7 +36,6 @@ from typing import ( ) from canonicaljson import encode_canonical_json -from typing_extensions import Literal from synapse.api.constants import EduTypes from synapse.api.errors import Codes, StoreError @@ -282,9 +282,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): "count_devices_by_users", count_devices_by_users_txn, user_ids ) + @cached() async def get_device( self, user_id: str, device_id: str - ) -> Optional[Dict[str, Any]]: + ) -> Optional[Mapping[str, Any]]: """Retrieve a device. Only returns devices that are not marked as hidden. @@ -670,9 +671,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): result["keys"] = keys device_display_name = None - if ( - self.hs.config.federation.allow_device_name_lookup_over_federation - ): + if self.hs.config.federation.allow_device_name_lookup_over_federation: device_display_name = device.display_name if device_display_name: result["device_display_name"] = device_display_name @@ -917,7 +916,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): from_key, to_key, ) - return {u for u, in rows} + return {u for (u,) in rows} @cancellable async def get_users_whose_devices_changed( @@ -968,7 +967,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): txn.database_engine, "user_id", chunk ) txn.execute(sql % (clause,), [from_key, to_key] + args) - changes.update(user_id for user_id, in txn) + changes.update(user_id for (user_id,) in txn) return changes @@ -1093,7 +1092,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ), ) - results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids} + results: Dict[str, Optional[str]] = dict.fromkeys(user_ids) results.update(rows) return results @@ -1424,7 +1423,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): DELETE FROM device_lists_outbound_last_success WHERE destination = ? AND user_id = ? """ - txn.execute_batch(sql, ((row[0], row[1]) for row in rows)) + txn.execute_batch(sql, [(row[0], row[1]) for row in rows]) logger.info("Pruned %d device list outbound pokes", count) @@ -1520,7 +1519,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): args: List[Any], ) -> Set[str]: txn.execute(sql.format(clause=clause), args) - return {user_id for user_id, in txn} + return {user_id for (user_id,) in txn} changes = set() for chunk in batch_iter(changed_room_ids, 1000): @@ -1560,7 +1559,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): txn: LoggingTransaction, ) -> Set[str]: txn.execute(sql, (from_id, to_id)) - return {room_id for room_id, in txn} + return {room_id for (room_id,) in txn} return await self.db_pool.runInteraction( "get_all_device_list_changes", @@ -1819,6 +1818,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): }, desc="store_device", ) + await self.invalidate_cache_and_stream("get_device", (user_id, device_id)) + if not inserted: # if the device already exists, check if it's a real device, or # if the device ID is reserved by something else @@ -1884,6 +1885,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values=device_ids, keyvalues={"user_id": user_id}, ) + self._invalidate_cache_and_stream_bulk( + txn, self.get_device, [(user_id, device_id) for device_id in device_ids] + ) for batch in batch_iter(device_ids, 100): await self.db_pool.runInteraction( @@ -1917,6 +1921,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): updatevalues=updates, desc="update_device", ) + await self.invalidate_cache_and_stream("get_device", (user_id, device_id)) async def update_remote_device_list_cache_entry( self, user_id: str, device_id: str, content: JsonDict, stream_id: str