diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 53024bddc3..6191f22cd6 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -27,6 +27,7 @@ from typing import (
Dict,
Iterable,
List,
+ Literal,
Mapping,
Optional,
Set,
@@ -35,7 +36,6 @@ from typing import (
)
from canonicaljson import encode_canonical_json
-from typing_extensions import Literal
from synapse.api.constants import EduTypes
from synapse.api.errors import Codes, StoreError
@@ -282,9 +282,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"count_devices_by_users", count_devices_by_users_txn, user_ids
)
+ @cached()
async def get_device(
self, user_id: str, device_id: str
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[Mapping[str, Any]]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
@@ -670,9 +671,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
result["keys"] = keys
device_display_name = None
- if (
- self.hs.config.federation.allow_device_name_lookup_over_federation
- ):
+ if self.hs.config.federation.allow_device_name_lookup_over_federation:
device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
@@ -917,7 +916,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
from_key,
to_key,
)
- return {u for u, in rows}
+ return {u for (u,) in rows}
@cancellable
async def get_users_whose_devices_changed(
@@ -968,7 +967,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
txn.database_engine, "user_id", chunk
)
txn.execute(sql % (clause,), [from_key, to_key] + args)
- changes.update(user_id for user_id, in txn)
+ changes.update(user_id for (user_id,) in txn)
return changes
@@ -1093,7 +1092,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
),
)
- results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
+ results: Dict[str, Optional[str]] = dict.fromkeys(user_ids)
results.update(rows)
return results
@@ -1424,7 +1423,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
DELETE FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ?
"""
- txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
+ txn.execute_batch(sql, [(row[0], row[1]) for row in rows])
logger.info("Pruned %d device list outbound pokes", count)
@@ -1520,7 +1519,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
args: List[Any],
) -> Set[str]:
txn.execute(sql.format(clause=clause), args)
- return {user_id for user_id, in txn}
+ return {user_id for (user_id,) in txn}
changes = set()
for chunk in batch_iter(changed_room_ids, 1000):
@@ -1560,7 +1559,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
txn: LoggingTransaction,
) -> Set[str]:
txn.execute(sql, (from_id, to_id))
- return {room_id for room_id, in txn}
+ return {room_id for (room_id,) in txn}
return await self.db_pool.runInteraction(
"get_all_device_list_changes",
@@ -1819,6 +1818,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
},
desc="store_device",
)
+ await self.invalidate_cache_and_stream("get_device", (user_id, device_id))
+
if not inserted:
# if the device already exists, check if it's a real device, or
# if the device ID is reserved by something else
@@ -1884,6 +1885,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values=device_ids,
keyvalues={"user_id": user_id},
)
+ self._invalidate_cache_and_stream_bulk(
+ txn, self.get_device, [(user_id, device_id) for device_id in device_ids]
+ )
for batch in batch_iter(device_ids, 100):
await self.db_pool.runInteraction(
@@ -1917,6 +1921,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
updatevalues=updates,
desc="update_device",
)
+ await self.invalidate_cache_and_stream("get_device", (user_id, device_id))
async def update_remote_device_list_cache_entry(
self, user_id: str, device_id: str, content: JsonDict, stream_id: str
|