diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 67953a3ed9..2b2208f3e6 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -216,7 +216,7 @@ class DeviceWorkerHandler:
)
# Then work out if any users have since joined
- rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
+ rooms_changed = await self.store.get_rooms_that_changed(room_ids, from_token.room_key)
member_events = await self.store.get_membership_changes_for_user(
user_id, from_token.room_key, now_room_key
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 7ab6003f61..c76a54c59b 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -47,11 +47,13 @@ from typing import (
Any,
Collection,
Dict,
+ FrozenSet,
Iterable,
List,
Optional,
Set,
Tuple,
+ Union,
cast,
overload,
)
@@ -81,6 +83,7 @@ from synapse.types import PersistedEventPosition, RoomStreamToken
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
+from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -603,9 +606,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
- list of recent events in the room
- stream ordering key for the start of the chunk of events returned.
"""
- room_ids = self._events_stream_cache.get_entities_changed(
- room_ids, from_key.stream
- )
+ room_ids = await self.get_rooms_that_changed(room_ids, from_key)
if not room_ids:
return {}
@@ -633,18 +634,86 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
- def get_rooms_that_changed(
+ async def get_rooms_that_changed(
self, room_ids: Collection[str], from_key: RoomStreamToken
) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have
been changes.
"""
+ if not room_ids:
+ return set()
+
from_id = from_key.stream
- return {
- room_id
- for room_id in room_ids
- if self._events_stream_cache.has_entity_changed(room_id, from_id)
- }
+
+ rooms_changed = self._events_stream_cache.get_entities_changed(
+ room_ids, from_id
+ )
+
+        # If strictly fewer rooms come back than we passed in, the stream
+        # cache was able to serve the query (a genuine cache hit), so trust it.
+        if len(rooms_changed) < len(room_ids):
+            return set(rooms_changed)
+
+ # If we didn't hit the cache let's query the DB for which rooms have had
+ # events since the given token.
+
+ def get_rooms_that_changed_txn(txn: LoggingTransaction) -> Set[str]:
+ results: Set[str] = set()
+ for batch in batch_iter(room_ids, 500):
+ batch = list(batch)
+ batch.sort()
+
+ room_id_clause, room_id_args = make_in_list_sql_clause(
+ self.database_engine, "room_id", batch
+ )
+
+ # For each room we want to get the max stream ordering, this is
+ # annoyingly hard to do in batches that correctly use the indices we
+            # have, cf. https://wiki.postgresql.org/wiki/Loose_indexscan.
+ #
+ # For a single room we can do a `ORDER BY stream DESC LIMIT 1`,
+ # which will correctly pull out the latest stream ordering
+ # efficiently. The following CTE forces postgres to do that one by
+ # one for each room. It works roughly by:
+            #    1. Order by room ID and stream ordering DESC with limit 1; this
+            #       will return one room with the maximum stream ordering.
+            #    2. Run the same query again, but with an added where clause to
+            #       exclude the previously selected rooms (i.e. add a `room_id <
+            #       prev_room_id`). Repeat until no rooms are left.
+ sql = f"""
+ WITH RECURSIVE t(room_id, stream_ordering) AS (
+ (
+ SELECT room_id, stream_ordering
+ FROM events
+ WHERE room_id = ?
+ ORDER BY room_id DESC, stream_ordering DESC
+ LIMIT 1
+ )
+ UNION ALL
+ (
+ SELECT new_row.* FROM t, LATERAL (
+ SELECT room_id, stream_ordering
+ FROM events
+ WHERE {room_id_clause} AND events.room_id < t.room_id
+ ORDER BY room_id DESC, stream_ordering DESC
+ LIMIT 1
+ ) AS new_row
+ )
+ )
+ SELECT room_id FROM t WHERE stream_ordering > ?
+ """
+
+ args = [batch[-1]]
+ args.extend(room_id_args)
+ args.append(from_id)
+
+ txn.execute(sql, args)
+ results.update(room_id for room_id, in txn)
+
+ return results
+
+ return await self.db_pool.runInteraction(
+ "get_rooms_that_changed", get_rooms_that_changed_txn
+ )
async def get_room_events_stream_for_room(
self,
|