diff options
-rw-r--r-- | synapse/handlers/device.py | 2 | ||||
-rw-r--r-- | synapse/storage/databases/main/stream.py | 87 |
2 files changed, 79 insertions, 10 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 67953a3ed9..2b2208f3e6 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -216,7 +216,7 @@ class DeviceWorkerHandler: ) # Then work out if any users have since joined - rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) + rooms_changed = await self.store.get_rooms_that_changed(room_ids, from_token.room_key) member_events = await self.store.get_membership_changes_for_user( user_id, from_token.room_key, now_room_key diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 7ab6003f61..c76a54c59b 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -47,11 +47,13 @@ from typing import ( Any, Collection, Dict, + FrozenSet, Iterable, List, Optional, Set, Tuple, + Union, cast, overload, ) @@ -81,6 +83,7 @@ from synapse.types import PersistedEventPosition, RoomStreamToken from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.cancellation import cancellable +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -603,9 +606,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): - list of recent events in the room - stream ordering key for the start of the chunk of events returned. """ - room_ids = self._events_stream_cache.get_entities_changed( - room_ids, from_key.stream - ) + room_ids = await self.get_rooms_that_changed(room_ids, from_key) if not room_ids: return {} @@ -633,18 +634,86 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return results - def get_rooms_that_changed( + async def get_rooms_that_changed( self, room_ids: Collection[str], from_key: RoomStreamToken ) -> Set[str]: """Given a list of rooms and a token, return rooms where there may have been changes. """ + if not room_ids: + return set() + from_id = from_key.stream - return { - room_id - for room_id in room_ids - if self._events_stream_cache.has_entity_changed(room_id, from_id) - } + + rooms_changed = self._events_stream_cache.get_entities_changed( + room_ids, from_id + ) + + # This is the easiest way to test if we actually hit the cache... + if len(rooms_changed) < len(room_ids): + return set(rooms_changed) + + # If we didn't hit the cache let's query the DB for which rooms have had + # events since the given token. + + def get_rooms_that_changed_txn(txn: LoggingTransaction) -> Set[str]: + results: Set[str] = set() + for batch in batch_iter(room_ids, 500): + batch = list(batch) + batch.sort() + + room_id_clause, room_id_args = make_in_list_sql_clause( + self.database_engine, "room_id", batch + ) + + # For each room we want to get the max stream ordering, this is + # annoyingly hard to do in batches that correctly use the indices we + # have, c.f. https://wiki.postgresql.org/wiki/Loose_indexscan. + # + # For a single room we can do a `ORDER BY stream DESC LIMIT 1`, + # which will correctly pull out the latest stream ordering + # efficiently. The following CTE forces postgres to do that one by + # one for each room. It works roughly by: + # 1. Order by room ID and stream ordering DESC with limit 1, this + # will return one room with the maximum stream ordering. + # 2. Run the same query again, but with an added where clause to + # exclude the previous selected rooms (i.e. add a `room_id < + # prev_room_id`). Repeat until no rooms left. + sql = f""" + WITH RECURSIVE t(room_id, stream_ordering) AS ( + ( + SELECT room_id, stream_ordering + FROM events + WHERE room_id = ? + ORDER BY room_id DESC, stream_ordering DESC + LIMIT 1 + ) + UNION ALL + ( + SELECT new_row.* FROM t, LATERAL ( + SELECT room_id, stream_ordering + FROM events + WHERE {room_id_clause} AND events.room_id < t.room_id + ORDER BY room_id DESC, stream_ordering DESC + LIMIT 1 + ) AS new_row + ) + ) + SELECT room_id FROM t WHERE stream_ordering > ? + """ + + args = [batch[-1]] + args.extend(room_id_args) + args.append(from_id) + + txn.execute(sql, args) + results.update(room_id for room_id, in txn) + + return results + + return await self.db_pool.runInteraction( + "get_rooms_that_changed", get_rooms_that_changed_txn + ) async def get_room_events_stream_for_room( self, |