diff --git a/changelog.d/17548.misc b/changelog.d/17548.misc
new file mode 100644
index 0000000000..861b241dcd
--- /dev/null
+++ b/changelog.d/17548.misc
@@ -0,0 +1 @@
+Fix performance of device lists in `/key/changes` and sliding sync.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index ce26c91a7b..4f2a9f3a5b 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -267,31 +267,27 @@ class DeviceWorkerHandler:
newly_left_rooms.add(change.room_id)
# We now work out if any other users have since joined or left the rooms
- # the user is currently in. First we filter out rooms that we know
- # haven't changed recently.
- rooms_changed = self.store.get_rooms_that_changed(
- joined_room_ids, from_token.room_key
- )
+ # the user is currently in.
# List of membership changes per room
room_to_deltas: Dict[str, List[StateDelta]] = {}
# The set of event IDs of membership events (so we can fetch their
# associated membership).
memberships_to_fetch: Set[str] = set()
- for room_id in rooms_changed:
- # TODO: Only pull out membership events?
- state_changes = await self.store.get_current_state_deltas_for_room(
- room_id, from_token=from_token.room_key, to_token=now_token.room_key
- )
- for delta in state_changes:
- if delta.event_type != EventTypes.Member:
- continue
- room_to_deltas.setdefault(room_id, []).append(delta)
- if delta.event_id:
- memberships_to_fetch.add(delta.event_id)
- if delta.prev_event_id:
- memberships_to_fetch.add(delta.prev_event_id)
+ # TODO: Only pull out membership events?
+ state_changes = await self.store.get_current_state_deltas_for_rooms(
+ joined_room_ids, from_token=from_token.room_key, to_token=now_token.room_key
+ )
+ for delta in state_changes:
+ if delta.event_type != EventTypes.Member:
+ continue
+
+ room_to_deltas.setdefault(delta.room_id, []).append(delta)
+ if delta.event_id:
+ memberships_to_fetch.add(delta.event_id)
+ if delta.prev_event_id:
+ memberships_to_fetch.add(delta.prev_event_id)
# Fetch all the memberships for the membership events
event_id_to_memberships = await self.store.get_membership_from_event_ids(
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 7d491d1728..eaa13da368 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -26,10 +26,11 @@ import attr
from synapse.logging.opentracing import trace
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction
+from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
from synapse.storage.databases.main.stream import _filter_results_by_stream
-from synapse.types import RoomStreamToken
+from synapse.types import RoomStreamToken, StrCollection
from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
@@ -200,3 +201,62 @@ class StateDeltasStore(SQLBaseStore):
return await self.db_pool.runInteraction(
"get_current_state_deltas_for_room", get_current_state_deltas_for_room_txn
)
+
+ @trace
+ async def get_current_state_deltas_for_rooms(
+ self,
+ room_ids: StrCollection,
+ from_token: RoomStreamToken,
+ to_token: RoomStreamToken,
+ ) -> List[StateDelta]:
+ """Get the state deltas between two tokens for the set of rooms."""
+
+ room_ids = self._curr_state_delta_stream_cache.get_entities_changed(
+ room_ids, from_token.stream
+ )
+ if not room_ids:
+ return []
+
+ def get_current_state_deltas_for_rooms_txn(
+ txn: LoggingTransaction,
+ room_ids: StrCollection,
+ ) -> List[StateDelta]:
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", room_ids
+ )
+
+ sql = f"""
+ SELECT instance_name, stream_id, room_id, type, state_key, event_id, prev_event_id
+ FROM current_state_delta_stream
+ WHERE {clause} AND ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ """
+ args.append(from_token.stream)
+ args.append(to_token.get_max_stream_pos())
+
+ txn.execute(sql, args)
+
+ return [
+ StateDelta(
+ stream_id=row[1],
+ room_id=row[2],
+ event_type=row[3],
+ state_key=row[4],
+ event_id=row[5],
+ prev_event_id=row[6],
+ )
+ for row in txn
+ if _filter_results_by_stream(from_token, to_token, row[0], row[1])
+ ]
+
+ results = []
+ for batch in batch_iter(room_ids, 1000):
+ deltas = await self.db_pool.runInteraction(
+ "get_current_state_deltas_for_rooms",
+ get_current_state_deltas_for_rooms_txn,
+ batch,
+ )
+
+ results.extend(deltas)
+
+ return results
|