diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 8dbcb3f5a0..f4410b5c02 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -70,10 +70,7 @@ from synapse.types import (
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.stream_change_cache import (
- AllEntitiesChangedResult,
- StreamChangeCache,
-)
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -132,6 +129,20 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
prefilled_cache=device_list_prefill,
)
+ device_list_room_prefill, min_device_list_room_id = self.db_pool.get_cache_dict(
+ db_conn,
+ "device_lists_changes_in_room",
+ entity_column="room_id",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=10000,
+ )
+ self._device_list_room_stream_cache = StreamChangeCache(
+ "DeviceListRoomStreamChangeCache",
+ min_device_list_room_id,
+ prefilled_cache=device_list_room_prefill,
+ )
+
(
user_signature_stream_prefill,
user_signature_stream_list_id,
@@ -209,6 +220,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
row.entity, token
)
+ def device_lists_in_rooms_have_changed(
+ self, room_ids: StrCollection, token: int
+ ) -> None:
+ "Record that device lists have changed in rooms"
+ for room_id in room_ids:
+ self._device_list_room_stream_cache.entity_has_changed(room_id, token)
+
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
@@ -832,16 +850,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
return {device[0]: db_to_json(device[1]) for device in devices}
- def get_cached_device_list_changes(
- self,
- from_key: int,
- ) -> AllEntitiesChangedResult:
- """Get set of users whose devices have changed since `from_key`, or None
- if that information is not in our cache.
- """
-
- return self._device_list_stream_cache.get_all_entities_changed(from_key)
-
@cancellable
async def get_all_devices_changed(
self,
@@ -1457,7 +1465,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
@cancellable
async def get_device_list_changes_in_rooms(
- self, room_ids: Collection[str], from_id: int
+ self, room_ids: Collection[str], from_id: int, to_id: int
) -> Optional[Set[str]]:
"""Return the set of users whose devices have changed in the given rooms
since the given stream ID.
@@ -1473,9 +1481,15 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
if min_stream_id > from_id:
return None
+ changed_room_ids = self._device_list_room_stream_cache.get_entities_changed(
+ room_ids, from_id
+ )
+ if not changed_room_ids:
+ return set()
+
sql = """
SELECT DISTINCT user_id FROM device_lists_changes_in_room
- WHERE {clause} AND stream_id >= ?
+ WHERE {clause} AND stream_id > ? AND stream_id <= ?
"""
def _get_device_list_changes_in_rooms_txn(
@@ -1487,11 +1501,12 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return {user_id for user_id, in txn}
changes = set()
- for chunk in batch_iter(room_ids, 1000):
+ for chunk in batch_iter(changed_room_ids, 1000):
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", chunk
)
args.append(from_id)
+ args.append(to_id)
changes |= await self.db_pool.runInteraction(
"get_device_list_changes_in_rooms",
@@ -1502,6 +1517,34 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return changes
+ async def get_all_device_list_changes(self, from_id: int, to_id: int) -> Set[str]:
+ """Return the set of rooms where devices have changed since the given
+ stream ID.
+
+ Will raise an exception if the given stream ID is too old.
+ """
+
+ min_stream_id = await self._get_min_device_lists_changes_in_room()
+
+ if min_stream_id > from_id:
+ raise Exception("stream ID is too old")
+
+ sql = """
+ SELECT DISTINCT room_id FROM device_lists_changes_in_room
+ WHERE stream_id > ? AND stream_id <= ?
+ """
+
+ def _get_all_device_list_changes_txn(
+ txn: LoggingTransaction,
+ ) -> Set[str]:
+ txn.execute(sql, (from_id, to_id))
+ return {room_id for room_id, in txn}
+
+ return await self.db_pool.runInteraction(
+ "get_all_device_list_changes",
+ _get_all_device_list_changes_txn,
+ )
+
async def get_device_list_changes_in_room(
self, room_id: str, min_stream_id: int
) -> Collection[Tuple[str, str]]:
@@ -1962,8 +2005,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
async def add_device_change_to_streams(
self,
user_id: str,
- device_ids: Collection[str],
- room_ids: Collection[str],
+ device_ids: StrCollection,
+ room_ids: StrCollection,
) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
@@ -2122,8 +2165,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self,
txn: LoggingTransaction,
user_id: str,
- device_ids: Iterable[str],
- room_ids: Collection[str],
+ device_ids: StrCollection,
+ room_ids: StrCollection,
stream_ids: List[int],
context: Dict[str, str],
) -> None:
@@ -2161,6 +2204,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
],
)
+ txn.call_after(
+ self.device_lists_in_rooms_have_changed, room_ids, max(stream_ids)
+ )
+
async def get_uncoverted_outbound_room_pokes(
self, start_stream_id: int, start_room_id: str, limit: int = 10
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
|