diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 0b6d1f2b05..3f0b2f5d84 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -282,9 +282,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"count_devices_by_users", count_devices_by_users_txn, user_ids
)
+ @cached()
async def get_device(
self, user_id: str, device_id: str
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[Mapping[str, Any]]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
@@ -1817,6 +1818,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
},
desc="store_device",
)
+ await self.invalidate_cache_and_stream("get_device", (user_id, device_id))
+
if not inserted:
# if the device already exists, check if it's a real device, or
# if the device ID is reserved by something else
@@ -1882,6 +1885,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values=device_ids,
keyvalues={"user_id": user_id},
)
+ self._invalidate_cache_and_stream_bulk(
+ txn, self.get_device, [(user_id, device_id) for device_id in device_ids]
+ )
for batch in batch_iter(device_ids, 100):
await self.db_pool.runInteraction(
@@ -1915,6 +1921,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
updatevalues=updates,
desc="update_device",
)
+ await self.invalidate_cache_and_stream("get_device", (user_id, device_id))
async def update_remote_device_list_cache_entry(
self, user_id: str, device_id: str, content: JsonDict, stream_id: str
|