summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9576.misc1
-rw-r--r--synapse/federation/federation_server.py6
-rw-r--r--synapse/handlers/federation.py6
-rw-r--r--synapse/storage/databases/main/event_federation.py148
-rw-r--r--tests/storage/test_event_federation.py76
5 files changed, 226 insertions, 11 deletions
diff --git a/changelog.d/9576.misc b/changelog.d/9576.misc
new file mode 100644
index 0000000000..bc257d05b7
--- /dev/null
+++ b/changelog.d/9576.misc
@@ -0,0 +1 @@
+Improve efficiency of calculating the auth chain in large rooms.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ffc735ba25..06c5e7a9e0 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -447,7 +447,7 @@ class FederationServer(FederationBase):
 
     async def _on_state_ids_request_compute(self, room_id, event_id):
         state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
-        auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
+        auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
         return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
 
     async def _on_context_state_request_compute(
@@ -460,7 +460,9 @@ class FederationServer(FederationBase):
         else:
             pdus = (await self.state.get_current_state(room_id)).values()
 
-        auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
+        auth_chain = await self.store.get_auth_chain(
+            room_id, [pdu.event_id for pdu in pdus]
+        )
 
         return {
             "pdus": [pdu.get_pdu_json() for pdu in pdus],
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 2ead626a4d..3fe02b7195 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1317,7 +1317,7 @@ class FederationHandler(BaseHandler):
     async def on_event_auth(self, event_id: str) -> List[EventBase]:
         event = await self.store.get_event(event_id)
         auth = await self.store.get_auth_chain(
-            list(event.auth_event_ids()), include_given=True
+            event.room_id, list(event.auth_event_ids()), include_given=True
         )
         return list(auth)
 
@@ -1580,7 +1580,7 @@ class FederationHandler(BaseHandler):
         prev_state_ids = await context.get_prev_state_ids()
 
         state_ids = list(prev_state_ids.values())
-        auth_chain = await self.store.get_auth_chain(state_ids)
+        auth_chain = await self.store.get_auth_chain(event.room_id, state_ids)
 
         state = await self.store.get_events(list(prev_state_ids.values()))
 
@@ -2219,7 +2219,7 @@ class FederationHandler(BaseHandler):
 
         # Now get the current auth_chain for the event.
         local_auth_chain = await self.store.get_auth_chain(
-            list(event.auth_event_ids()), include_given=True
+            room_id, list(event.auth_event_ids()), include_given=True
         )
 
         # TODO: Check if we would now reject event_id. If so we need to tell
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 18ddb92fcc..332193ad1c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -54,11 +54,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         )  # type: LruCache[str, List[Tuple[str, int]]]
 
     async def get_auth_chain(
-        self, event_ids: Collection[str], include_given: bool = False
+        self, room_id: str, event_ids: Collection[str], include_given: bool = False
     ) -> List[EventBase]:
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
+            room_id: The room the event is in.
             event_ids: state events
             include_given: include the given events in result
 
@@ -66,24 +67,44 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             list of events
         """
         event_ids = await self.get_auth_chain_ids(
-            event_ids, include_given=include_given
+            room_id, event_ids, include_given=include_given
         )
         return await self.get_events_as_list(event_ids)
 
     async def get_auth_chain_ids(
         self,
+        room_id: str,
         event_ids: Collection[str],
         include_given: bool = False,
     ) -> List[str]:
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
+            room_id: The room the event is in.
             event_ids: state events
             include_given: include the given events in result
 
         Returns:
-            An awaitable which resolve to a list of event_ids
+            list of event_ids
         """
+
+        # Check if we have indexed the room so we can use the chain cover
+        # algorithm.
+        room = await self.get_room(room_id)
+        if room["has_auth_chain_index"]:
+            try:
+                return await self.db_pool.runInteraction(
+                    "get_auth_chain_ids_chains",
+                    self._get_auth_chain_ids_using_cover_index_txn,
+                    room_id,
+                    event_ids,
+                    include_given,
+                )
+            except _NoChainCoverIndex:
+                # For whatever reason we don't actually have a chain cover index
+                # for the events in question, so we fall back to the old method.
+                pass
+
         return await self.db_pool.runInteraction(
             "get_auth_chain_ids",
             self._get_auth_chain_ids_txn,
@@ -91,9 +112,130 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             include_given,
         )
 
+    def _get_auth_chain_ids_using_cover_index_txn(
+        self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
+    ) -> List[str]:
+        """Calculates the auth chain IDs using the chain index."""
+
+        # First we look up the chain ID/sequence numbers for the given events.
+
+        initial_events = set(event_ids)
+
+        # All the events that we've found that are reachable from the events.
+        seen_events = set()  # type: Set[str]
+
+        # A map from chain ID to max sequence number of the given events.
+        event_chains = {}  # type: Dict[int, int]
+
+        sql = """
+            SELECT event_id, chain_id, sequence_number
+            FROM event_auth_chains
+            WHERE %s
+        """
+        for batch in batch_iter(initial_events, 1000):
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "event_id", batch
+            )
+            txn.execute(sql % (clause,), args)
+
+            for event_id, chain_id, sequence_number in txn:
+                seen_events.add(event_id)
+                event_chains[chain_id] = max(
+                    sequence_number, event_chains.get(chain_id, 0)
+                )
+
+        # Check that we actually have a chain ID for all the events.
+        events_missing_chain_info = initial_events.difference(seen_events)
+        if events_missing_chain_info:
+            # This can happen due to e.g. downgrade/upgrade of the server. We
+            # raise an exception and fall back to the previous algorithm.
+            logger.info(
+                "Unexpectedly found that events don't have chain IDs in room %s: %s",
+                room_id,
+                events_missing_chain_info,
+            )
+            raise _NoChainCoverIndex(room_id)
+
+        # Now we look up all links for the chains we have, adding chains that
+        # are reachable from any event.
+        sql = """
+            SELECT
+                origin_chain_id, origin_sequence_number,
+                target_chain_id, target_sequence_number
+            FROM event_auth_chain_links
+            WHERE %s
+        """
+
+        # A map from chain ID to max sequence number *reachable* from any event ID.
+        chains = {}  # type: Dict[int, int]
+
+        # Add all linked chains reachable from initial set of chains.
+        for batch in batch_iter(event_chains, 1000):
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "origin_chain_id", batch
+            )
+            txn.execute(sql % (clause,), args)
+
+            for (
+                origin_chain_id,
+                origin_sequence_number,
+                target_chain_id,
+                target_sequence_number,
+            ) in txn:
+                # chains are only reachable if the origin sequence number of
+                # the link is less than the max sequence number in the
+                # origin chain.
+                if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
+                    chains[target_chain_id] = max(
+                        target_sequence_number,
+                        chains.get(target_chain_id, 0),
+                    )
+
+        # Add the initial set of chains, excluding the sequence corresponding to
+        # initial event.
+        for chain_id, seq_no in event_chains.items():
+            chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))
+
+        # Now for each chain we figure out the maximum sequence number reachable
+        # from *any* event ID. Events with a sequence less than that are in the
+        # auth chain.
+        if include_given:
+            results = initial_events
+        else:
+            results = set()
+
+        if isinstance(self.database_engine, PostgresEngine):
+            # We can use `execute_values` to efficiently fetch the gaps when
+            # using postgres.
+            sql = """
+                SELECT event_id
+                FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
+                WHERE
+                    c.chain_id = l.chain_id
+                    AND sequence_number <= max_seq
+            """
+
+            rows = txn.execute_values(sql, chains.items())
+            results.update(r for r, in rows)
+        else:
+            # For SQLite we just fall back to doing a noddy for loop.
+            sql = """
+                SELECT event_id FROM event_auth_chains
+                WHERE chain_id = ? AND sequence_number <= ?
+            """
+            for chain_id, max_no in chains.items():
+                txn.execute(sql, (chain_id, max_no))
+                results.update(r for r, in txn)
+
+        return list(results)
+
     def _get_auth_chain_ids_txn(
         self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
     ) -> List[str]:
+        """Calculates the auth chain IDs.
+
+        This is used when we don't have a cover index for the room.
+        """
         if include_given:
             results = set(event_ids)
         else:
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 06000f81a6..d597d712d6 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -118,8 +118,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
         self.assertTrue(r == [room2] or r == [room3])
 
-    @parameterized.expand([(True,), (False,)])
-    def test_auth_difference(self, use_chain_cover_index: bool):
+    def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
         room_id = "@ROOM:local"
 
         # The silly auth graph we use to test the auth difference algorithm,
@@ -165,7 +164,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             "j": 1,
         }
 
-        # Mark the room as not having a cover index
+        # Mark the room as maybe having a cover index.
 
         def store_room(txn):
             self.store.db_pool.simple_insert_txn(
@@ -222,6 +221,77 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             )
         )
 
+        return room_id
+
+    @parameterized.expand([(True,), (False,)])
+    def test_auth_chain_ids(self, use_chain_cover_index: bool):
+        room_id = self._setup_auth_chain(use_chain_cover_index)
+
+        # a and b have the same auth chain.
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"]))
+        self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"]))
+        self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+        auth_chain_ids = self.get_success(
+            self.store.get_auth_chain_ids(room_id, ["a", "b"])
+        )
+        self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"]))
+        self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+        # d and e have the same auth chain.
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"]))
+        self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"]))
+        self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"]))
+        self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"]))
+        self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])
+
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
+        self.assertEqual(auth_chain_ids, ["k"])
+
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
+        self.assertEqual(auth_chain_ids, ["j"])
+
+        # j and k have no parents.
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
+        self.assertEqual(auth_chain_ids, [])
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
+        self.assertEqual(auth_chain_ids, [])
+
+        # More complex input sequences.
+        auth_chain_ids = self.get_success(
+            self.store.get_auth_chain_ids(room_id, ["b", "c", "d"])
+        )
+        self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+        auth_chain_ids = self.get_success(
+            self.store.get_auth_chain_ids(room_id, ["h", "i"])
+        )
+        self.assertCountEqual(auth_chain_ids, ["k", "j"])
+
+        # e gets returned even though include_given is false, but it is in the
+        # auth chain of b.
+        auth_chain_ids = self.get_success(
+            self.store.get_auth_chain_ids(room_id, ["b", "e"])
+        )
+        self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+        # Test include_given.
+        auth_chain_ids = self.get_success(
+            self.store.get_auth_chain_ids(room_id, ["i"], include_given=True)
+        )
+        self.assertCountEqual(auth_chain_ids, ["i", "j"])
+
+    @parameterized.expand([(True,), (False,)])
+    def test_auth_difference(self, use_chain_cover_index: bool):
+        room_id = self._setup_auth_chain(use_chain_cover_index)
+
         # Now actually test that various combinations give the right result:
 
         difference = self.get_success(