diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4a4d35f77c..a569e8146a 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -14,7 +14,7 @@
import itertools
import logging
from queue import Empty, PriorityQueue
-from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple, NamedTuple
from prometheus_client import Counter, Gauge
@@ -53,6 +53,14 @@ pdus_pruned_from_federation_queue = Counter(
logger = logging.getLogger(__name__)
+# All the info we need while iterating the DAG while backfilling
+class BackfillQueueNavigationItem(NamedTuple):
+ depth: int
+ stream_ordering: int
+ event_id: str
+ type: str
+
+
class _NoChainCoverIndex(Exception):
def __init__(self, room_id: str):
super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
@@ -987,6 +995,117 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
+ async def get_connected_insertion_event_backfill_results(
+ self, event_id: str, limit: int
+ ) -> List[BackfillQueueNavigationItem]:
+ def _get_connected_insertion_event_backfill_results_txn(txn):
+ # Look for the "insertion" events connected to the given event_id
+ connected_insertion_event_query = """
+ SELECT e.depth, e.stream_ordering, i.event_id, e.type FROM insertion_event_edges AS i
+ /* Get the depth of the insertion event from the events table */
+ INNER JOIN events AS e USING (event_id)
+ /* Find an insertion event which points via prev_events to the given event_id */
+ WHERE i.insertion_prev_event_id = ?
+ LIMIT ?
+ """
+
+ txn.execute(
+ connected_insertion_event_query,
+ (event_id, limit),
+ )
+ connected_insertion_event_id_results = txn.fetchall()
+ return [
+ BackfillQueueNavigationItem(
+ depth=row[0],
+ stream_ordering=row[1],
+ event_id=row[2],
+ type=row[3],
+ )
+ for row in connected_insertion_event_id_results
+ ]
+
+ return await self.db_pool.runInteraction(
+ "get_connected_insertion_event_backfill_results",
+ _get_connected_insertion_event_backfill_results_txn,
+ )
+
+ async def get_connected_batch_event_backfill_results(
+ self, insertion_event_id: str, limit: int
+ ) -> List[BackfillQueueNavigationItem]:
+ def _get_connected_batch_event_backfill_results_txn(txn):
+ # Find any batch connections of a given insertion event
+ batch_connection_query = """
+ SELECT e.depth, e.stream_ordering, c.event_id, e.type FROM insertion_events AS i
+ /* Find the batch that connects to the given insertion event */
+ INNER JOIN batch_events AS c
+ ON i.next_batch_id = c.batch_id
+ /* Get the depth of the batch start event from the events table */
+ INNER JOIN events AS e USING (event_id)
+ /* Find an insertion event which matches the given event_id */
+ WHERE i.event_id = ?
+ LIMIT ?
+ """
+
+ # Find any batch connections for the given insertion event
+ txn.execute(
+ batch_connection_query,
+ (insertion_event_id, limit),
+ )
+ batch_start_event_id_results = txn.fetchall()
+ return [
+ BackfillQueueNavigationItem(
+ depth=row[0],
+ stream_ordering=row[1],
+ event_id=row[2],
+ type=row[3],
+ )
+ for row in batch_start_event_id_results
+ ]
+
+ return await self.db_pool.runInteraction(
+ "get_connected_batch_event_backfill_results",
+ _get_connected_batch_event_backfill_results_txn,
+ )
+
+ async def get_connected_prev_event_backfill_results(
+ self, event_id: str, limit: int
+ ) -> List[BackfillQueueNavigationItem]:
+ def _get_connected_prev_event_backfill_results_txn(txn):
+ # Look for the prev_event_id connected to the given event_id
+ connected_prev_event_query = """
+ SELECT depth, stream_ordering, prev_event_id, events.type FROM event_edges
+ /* Get the depth and stream_ordering of the prev_event_id from the events table */
+ INNER JOIN events
+ ON prev_event_id = events.event_id
+ /* Look for an edge which matches the given event_id */
+ WHERE event_edges.event_id = ?
+ AND event_edges.is_state = ?
+ /* Because we can have many events at the same depth,
+ * we want to also tie-break and sort on stream_ordering */
+ ORDER BY depth DESC, stream_ordering DESC
+ LIMIT ?
+ """
+
+ txn.execute(
+ connected_prev_event_query,
+ (event_id, False, limit),
+ )
+ prev_event_id_results = txn.fetchall()
+ return [
+ BackfillQueueNavigationItem(
+ depth=row[0],
+ stream_ordering=row[1],
+ event_id=row[2],
+ type=row[3],
+ )
+ for row in prev_event_id_results
+ ]
+
+ return await self.db_pool.runInteraction(
+ "get_connected_prev_event_backfill_results",
+ _get_connected_prev_event_backfill_results_txn,
+ )
+
async def get_backfill_events(self, room_id: str, event_list: list, limit: int):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`
|