summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/device.py2
-rw-r--r--synapse/storage/databases/main/stream.py87
2 files changed, 79 insertions, 10 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 67953a3ed9..2b2208f3e6 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -216,7 +216,7 @@ class DeviceWorkerHandler:
         )
 
         # Then work out if any users have since joined
-        rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
+        rooms_changed = await self.store.get_rooms_that_changed(room_ids, from_token.room_key)
 
         member_events = await self.store.get_membership_changes_for_user(
             user_id, from_token.room_key, now_room_key
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 7ab6003f61..c76a54c59b 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -47,11 +47,13 @@ from typing import (
     Any,
     Collection,
     Dict,
+    FrozenSet,
     Iterable,
     List,
     Optional,
     Set,
     Tuple,
+    Union,
     cast,
     overload,
 )
@@ -81,6 +83,7 @@ from synapse.types import PersistedEventPosition, RoomStreamToken
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.cancellation import cancellable
+from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -603,9 +606,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 - list of recent events in the room
                 - stream ordering key for the start of the chunk of events returned.
         """
-        room_ids = self._events_stream_cache.get_entities_changed(
-            room_ids, from_key.stream
-        )
+        room_ids = await self.get_rooms_that_changed(room_ids, from_key)
 
         if not room_ids:
             return {}
@@ -633,18 +634,86 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return results
 
-    def get_rooms_that_changed(
+    async def get_rooms_that_changed(
         self, room_ids: Collection[str], from_key: RoomStreamToken
     ) -> Set[str]:
         """Given a list of rooms and a token, return rooms where there may have
         been changes.
         """
+        if not room_ids:
+            return set()
+
         from_id = from_key.stream
-        return {
-            room_id
-            for room_id in room_ids
-            if self._events_stream_cache.has_entity_changed(room_id, from_id)
-        }
+
+        rooms_changed = self._events_stream_cache.get_entities_changed(
+            room_ids, from_id
+        )
+
+        # This is the easiest way to test if we actually hit the cache...
+        if len(rooms_changed) < len(room_ids):
+            return set(rooms_changed)
+
+        # If we didn't hit the cache let's query the DB for which rooms have had
+        # events since the given token.
+
+        def get_rooms_that_changed_txn(txn: LoggingTransaction) -> Set[str]:
+            results: Set[str] = set()
+            for batch in batch_iter(room_ids, 500):
+                batch = list(batch)
+                batch.sort()
+
+                room_id_clause, room_id_args = make_in_list_sql_clause(
+                    self.database_engine, "room_id", batch
+                )
+
+                # For each room we want to get the max stream ordering, this is
+                # annoyingly hard to do in batches that correctly use the indices we
+                # have, c.f. https://wiki.postgresql.org/wiki/Loose_indexscan.
+                #
+                # For a single room we can do a `ORDER BY stream DESC LIMIT 1`,
+                # which will correctly pull out the latest stream ordering
+                # efficiently. The following CTE forces postgres to do that one by
+                # one for each room. It works roughly by:
+                #   1. Order by room ID and stream ordering DESC with limit 1, this
+                #      will return one room with the maximum stream ordering.
+                #   2. Run the same query again, but with an added where clause to
+                #      exclude the previous selected rooms (i.e. add a `room_id <
+                #      prev_room_id`). Repeat until no rooms left.
+                sql = f"""
+                    WITH RECURSIVE t(room_id, stream_ordering) AS (
+                        (
+                            SELECT room_id, stream_ordering
+                            FROM events
+                            WHERE room_id = ?
+                            ORDER BY room_id DESC, stream_ordering DESC
+                            LIMIT 1
+                        )
+                        UNION ALL
+                        (
+                            SELECT new_row.* FROM t, LATERAL (
+                                SELECT room_id, stream_ordering
+                                FROM events
+                                WHERE {room_id_clause} AND events.room_id < t.room_id
+                                ORDER BY room_id DESC, stream_ordering DESC
+                                LIMIT 1
+                            ) AS new_row
+                        )
+                    )
+                    SELECT room_id FROM t WHERE stream_ordering > ?
+                """
+
+                args = [batch[-1]]
+                args.extend(room_id_args)
+                args.append(from_id)
+
+                txn.execute(sql, args)
+                results.update(room_id for room_id, in txn)
+
+            return results
+
+        return await self.db_pool.runInteraction(
+            "get_rooms_that_changed", get_rooms_that_changed_txn
+        )
 
     async def get_room_events_stream_for_room(
         self,