diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c478e0bc5c..e28c74daf0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -16,6 +16,7 @@
"""Contains handlers for federation events."""
import logging
 from http import HTTPStatus
+from queue import Empty, PriorityQueue
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
@@ -1041,6 +1042,135 @@ class FederationHandler:
else:
return []
+ async def get_backfill_events(
+ self, room_id: str, event_id_list: List[str], limit: int
+ ) -> List[EventBase]:
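+ """Get a list of events to backfill, starting from the given seed
+ events and navigating backwards through the room DAG (including any
+ MSC2716 historical batches along the way, if enabled).
+
+ Args:
+ room_id: The room we are backfilling in
+ event_id_list: The seed event IDs to start navigating from
+ limit: The maximum number of events to return
+
+ Returns:
+ A list of events sorted by depth and stream_ordering,
+ newest-in-time first
+ """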
+ event_id_results = set()
+
+ # In a PriorityQueue, the lowest valued entries are retrieved first.
+ # We're using depth as the priority in the queue and tie-break based on
+ # stream_ordering. Depth is lowest at the oldest-in-time message and
+ # highest and newest-in-time message. We add events to the queue with a
+ # negative depth so that we process the newest-in-time messages first
+ # going backwards in time. stream_ordering follows the same pattern.
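+ #
+ # For example, an event at depth 10 is queued as (-10, ...) and an
+ # event at depth 5 as (-5, ...); since -10 < -5, the depth-10
+ # (newer) event is retrieved from the queue first.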
+ queue = PriorityQueue()
+
+ seed_events = await self.store.get_events_as_list(event_id_list)
+ for seed_event in seed_events:
+ # Make sure the seed event actually pertains to this room. We also
+ # need to make sure the depth is available since our whole DAG
+ # navigation here depends on depth.
+ if seed_event.room_id == room_id and seed_event.depth:
+ queue.put(
+ (
+ -seed_event.depth,
+ -seed_event.internal_metadata.stream_ordering,
+ seed_event.event_id,
+ seed_event.type,
+ )
+ )
+
+ while not queue.empty() and len(event_id_results) < limit:
+ try:
+ _, _, event_id, event_type = queue.get_nowait()
+ except Empty:
+ break
+
+ if event_id in event_id_results:
+ continue
+
+ event_id_results.add(event_id)
+
+ if self.hs.config.experimental.msc2716_enabled:
+ # Try to find any potential historical batches of message history.
+ #
+ # First we look for an insertion event connected to the current
+ # event (by prev_event). If we find any, we'll add them to the queue
+ # and navigate up the DAG like normal in the next iteration of the
+ # loop.
+ connected_insertion_event_backfill_results = (
+ await self.store.get_connected_insertion_event_backfill_results(
+ event_id, limit - len(event_id_results)
+ )
+ )
+ logger.debug(
+ "_get_backfill_events: connected_insertion_event_backfill_results=%s",
+ connected_insertion_event_backfill_results,
+ )
+ for (
+ connected_insertion_event_backfill_item
+ ) in connected_insertion_event_backfill_results:
+ if (
+ connected_insertion_event_backfill_item.event_id
+ not in event_id_results
+ ):
+ queue.put(
+ (
+ -connected_insertion_event_backfill_item.depth,
+ -connected_insertion_event_backfill_item.stream_ordering,
+ connected_insertion_event_backfill_item.event_id,
+ connected_insertion_event_backfill_item.type,
+ )
+ )
+
+ # Second, we need to go and try to find any batch events connected
+ # to a given insertion event (by batch_id). If we find any, we'll
+ # add them to the queue and navigate up the DAG like normal in the
+ # next iteration of the loop.
+ if event_type == EventTypes.MSC2716_INSERTION:
+ connected_batch_event_backfill_results = (
+ await self.store.get_connected_batch_event_backfill_results(
+ event_id, limit - len(event_id_results)
+ )
+ )
+ logger.debug(
+ "_get_backfill_events: connected_batch_event_backfill_results %s",
+ connected_batch_event_backfill_results,
+ )
+ for (
+ connected_batch_event_backfill_item
+ ) in connected_batch_event_backfill_results:
+ if (
+ connected_batch_event_backfill_item.event_id
+ not in event_id_results
+ ):
+ queue.put(
+ (
+ -connected_batch_event_backfill_item.depth,
+ -connected_batch_event_backfill_item.stream_ordering,
+ connected_batch_event_backfill_item.event_id,
+ connected_batch_event_backfill_item.type,
+ )
+ )
+
+ # Now we just navigate up the DAG by prev_events as normal
+ connected_prev_event_backfill_results = (
+ await self.store.get_connected_prev_event_backfill_results(
+ event_id, limit - len(event_id_results)
+ )
+ )
+ logger.debug(
+ "_get_backfill_events: prev_event_ids %s",
+ connected_prev_event_backfill_results,
+ )
+ for (
+ connected_prev_event_backfill_item
+ ) in connected_prev_event_backfill_results:
+ if connected_prev_event_backfill_item.event_id not in event_id_results:
+ queue.put(
+ (
+ -connected_prev_event_backfill_item.depth,
+ -connected_prev_event_backfill_item.stream_ordering,
+ connected_prev_event_backfill_item.event_id,
+ connected_prev_event_backfill_item.type,
+ )
+ )
+
+ events = await self.store.get_events_as_list(event_id_results)
+ return sorted(
+ events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering)
+ )
+
@log_function
async def on_backfill_request(
self, origin: str, room_id: str, pdu_list: List[str], limit: int
@@ -1053,6 +1183,34 @@ class FederationHandler:
limit = min(limit, 100)
events = await self.store.get_backfill_events(room_id, pdu_list, limit)
+ logger.info(
+ "old implementation backfill events=%s",
+ [
+ "event_id=%s,depth=%d,body=%s,prevs=%s\n"
+ % (
+ event.event_id,
+ event.depth,
+ event.content.get("body", event.type),
+ event.prev_event_ids(),
+ )
+ for event in events
+ ],
+ )
+
+ events = await self.get_backfill_events(room_id, pdu_list, limit)
+ logger.info(
+ "new implementation backfill events=%s",
+ [
+ "event_id=%s,depth=%d,body=%s,prevs=%s\n"
+ % (
+ event.event_id,
+ event.depth,
+ event.content.get("body", event.type),
+ event.prev_event_ids(),
+ )
+ for event in events
+ ],
+ )
events = await filter_events_for_server(self.storage, origin, events)
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index ab2ed53bce..c9060a594f 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -702,38 +702,38 @@ class FederationEventHandler:
event.event_id
)
- # Maybe we can get lucky and save ourselves a lookup
- # by checking the events in the backfill first
- insertion_event = event_map[
- insertion_event_id
- ] or await self._store.get_event(
- insertion_event_id, allow_none=True
- )
-
- if insertion_event:
- # Connect the insertion events' `prev_event` successors
- # via fake edges pointing to the insertion event itself
- # so the insertion event sorts topologically
- # behind-in-time the successor. Nestled perfectly
- # between the prev_event and the successor.
- for insertion_prev_event_id in insertion_event.prev_event_ids():
- successor_event_ids = successor_event_id_map[
- insertion_prev_event_id
- ]
- logger.info(
- "insertion_event_id=%s successor_event_ids=%s",
- insertion_event_id,
- successor_event_ids,
- )
- if successor_event_ids:
- for successor_event_id in successor_event_ids:
- # Don't add itself back as a successor
- if successor_event_id != insertion_event_id:
- # Fake edge to point the successor back
- # at the insertion event
- event_id_graph.setdefault(
- successor_event_id, []
- ).append(insertion_event_id)
+ # # Maybe we can get lucky and save ourselves a lookup
+ # # by checking the events in the backfill first
+ # insertion_event = event_map[
+ # insertion_event_id
+ # ] or await self._store.get_event(
+ # insertion_event_id, allow_none=True
+ # )
+
+ # if insertion_event:
+ # # Connect the insertion events' `prev_event` successors
+ # # via fake edges pointing to the insertion event itself
+ # # so the insertion event sorts topologically
+ # # behind-in-time the successor. Nestled perfectly
+ # # between the prev_event and the successor.
+ # for insertion_prev_event_id in insertion_event.prev_event_ids():
+ # successor_event_ids = successor_event_id_map[
+ # insertion_prev_event_id
+ # ]
+ # logger.info(
+ # "insertion_event_id=%s successor_event_ids=%s",
+ # insertion_event_id,
+ # successor_event_ids,
+ # )
+ # if successor_event_ids:
+ # for successor_event_id in successor_event_ids:
+ # # Don't add itself back as a successor
+ # if successor_event_id != insertion_event_id:
+ # # Fake edge to point the successor back
+ # # at the insertion event
+ # event_id_graph.setdefault(
+ # successor_event_id, []
+ # ).append(insertion_event_id)
# TODO: We also need to add fake edges to connect the oldest-in-time messages
# in the batch to the event we branched off of, see https://github.com/matrix-org/synapse/pull/11114#discussion_r739300985
@@ -773,17 +773,17 @@ class FederationEventHandler:
# We want to sort these by depth so we process them and
# tell clients about them in order.
- # sorted_events = sorted(events, key=lambda x: x.depth)
-
- # We want to sort topologically so we process them and tell clients
- # about them in order.
- sorted_events = []
- event_ids = [event.event_id for event in events]
- event_map = {event.event_id: event for event in events}
- event_id_graph = await self.generateEventIdGraphFromEvents(events)
- for event_id in sorted_topologically(event_ids, event_id_graph):
- sorted_events.append(event_map[event_id])
- sorted_events = reversed(sorted_events)
+ sorted_events = sorted(events, key=lambda x: x.depth)
+
+ # # We want to sort topologically so we process them and tell clients
+ # # about them in order.
+ # sorted_events = []
+ # event_ids = [event.event_id for event in events]
+ # event_map = {event.event_id: event for event in events}
+ # event_id_graph = await self.generateEventIdGraphFromEvents(events)
+ # for event_id in sorted_topologically(event_ids, event_id_graph):
+ # sorted_events.append(event_map[event_id])
+ # sorted_events = reversed(sorted_events)
logger.info(
"backfill sorted_events=%s",
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4a4d35f77c..a569e8146a 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -14,7 +14,7 @@
import itertools
import logging
from queue import Empty, PriorityQueue
-from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import Collection, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple
from prometheus_client import Counter, Gauge
@@ -53,6 +53,14 @@ pdus_pruned_from_federation_queue = Counter(
logger = logging.getLogger(__name__)
+# All the info we need to navigate the DAG while backfilling
+class BackfillQueueNavigationItem(NamedTuple):
+ depth: int  # topological depth of the event in the room DAG
+ stream_ordering: int  # tie-breaker between events at the same depth
+ event_id: str
+ type: str  # event type, e.g. to spot MSC2716 insertion events
+
+
class _NoChainCoverIndex(Exception):
def __init__(self, room_id: str):
super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
@@ -987,6 +995,117 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
+ async def get_connected_insertion_event_backfill_results(
+ self, event_id: str, limit: int
+ ) -> List[BackfillQueueNavigationItem]:
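+ """Find the "insertion" events which point at the given event via
+ insertion_prev_event_id, so any historical batches hanging off of
+ it can be navigated during backfill.
+
+ Args:
+ event_id: The event to find connected insertion events for
+ limit: The maximum number of results to return
+
+ Returns:
+ A list of BackfillQueueNavigationItem
+ """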
+ def _get_connected_insertion_event_backfill_results_txn(txn):
+ # Look for the "insertion" events connected to the given event_id
+ connected_insertion_event_query = """
+ SELECT e.depth, e.stream_ordering, i.event_id, e.type FROM insertion_event_edges AS i
+ /* Get the depth and stream_ordering of the insertion event from the events table */
+ INNER JOIN events AS e USING (event_id)
+ /* Find an insertion event which points via prev_events to the given event_id */
+ WHERE i.insertion_prev_event_id = ?
+ LIMIT ?
+ """
+
+ txn.execute(
+ connected_insertion_event_query,
+ (event_id, limit),
+ )
+ connected_insertion_event_id_results = txn.fetchall()
+ return [
+ BackfillQueueNavigationItem(
+ depth=row[0],
+ stream_ordering=row[1],
+ event_id=row[2],
+ type=row[3],
+ )
+ for row in connected_insertion_event_id_results
+ ]
+
+ return await self.db_pool.runInteraction(
+ "get_connected_insertion_event_backfill_results",
+ _get_connected_insertion_event_backfill_results_txn,
+ )
+
+ async def get_connected_batch_event_backfill_results(
+ self, insertion_event_id: str, limit: int
+ ) -> List[BackfillQueueNavigationItem]:
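+ """Find the batch start event connected to the given insertion
+ event (via the insertion event's next_batch_id).
+
+ Args:
+ insertion_event_id: The insertion event to find the connected
+ batch for
+ limit: The maximum number of results to return
+
+ Returns:
+ A list of BackfillQueueNavigationItem
+ """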
+ def _get_connected_batch_event_backfill_results_txn(txn):
+ # Find any batch connections of a given insertion event
+ batch_connection_query = """
+ SELECT e.depth, e.stream_ordering, c.event_id, e.type FROM insertion_events AS i
+ /* Find the batch that connects to the given insertion event */
+ INNER JOIN batch_events AS c
+ ON i.next_batch_id = c.batch_id
+ /* Get the depth and stream_ordering of the batch start event from the events table */
+ INNER JOIN events AS e USING (event_id)
+ /* Find an insertion event which matches the given event_id */
+ WHERE i.event_id = ?
+ LIMIT ?
+ """
+
+ txn.execute(
+ batch_connection_query,
+ (insertion_event_id, limit),
+ )
+ batch_start_event_id_results = txn.fetchall()
+ return [
+ BackfillQueueNavigationItem(
+ depth=row[0],
+ stream_ordering=row[1],
+ event_id=row[2],
+ type=row[3],
+ )
+ for row in batch_start_event_id_results
+ ]
+
+ return await self.db_pool.runInteraction(
+ "get_connected_batch_event_backfill_results",
+ _get_connected_batch_event_backfill_results_txn,
+ )
+
+ async def get_connected_prev_event_backfill_results(
+ self, event_id: str, limit: int
+ ) -> List[BackfillQueueNavigationItem]:
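+ """Find the prev_events of the given event, ordered by depth and
+ stream_ordering descending so the newest-in-time events come first.
+
+ Args:
+ event_id: The event to find prev_events for
+ limit: The maximum number of results to return
+
+ Returns:
+ A list of BackfillQueueNavigationItem
+ """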
+ def _get_connected_prev_event_backfill_results_txn(txn):
+ # Look for the prev_event_id connected to the given event_id
+ connected_prev_event_query = """
+ SELECT depth, stream_ordering, prev_event_id, events.type FROM event_edges
+ /* Get the depth and stream_ordering of the prev_event_id from the events table */
+ INNER JOIN events
+ ON prev_event_id = events.event_id
+ /* Look for an edge which matches the given event_id */
+ WHERE event_edges.event_id = ?
+ AND event_edges.is_state = ?
+ /* Because we can have many events at the same depth,
+ * we want to also tie-break and sort on stream_ordering */
+ ORDER BY depth DESC, stream_ordering DESC
+ LIMIT ?
+ """
+
+ txn.execute(
+ connected_prev_event_query,
+ (event_id, False, limit),
+ )
+ prev_event_id_results = txn.fetchall()
+ return [
+ BackfillQueueNavigationItem(
+ depth=row[0],
+ stream_ordering=row[1],
+ event_id=row[2],
+ type=row[3],
+ )
+ for row in prev_event_id_results
+ ]
+
+ return await self.db_pool.runInteraction(
+ "get_connected_prev_event_backfill_results",
+ _get_connected_prev_event_backfill_results_txn,
+ )
+
async def get_backfill_events(self, room_id: str, event_list: list, limit: int):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`
|