diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c478e0bc5c..e28c74daf0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -16,6 +16,7 @@
"""Contains handlers for federation events."""
import logging
 from http import HTTPStatus
+from queue import Empty, PriorityQueue
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
@@ -1041,6 +1042,135 @@ class FederationHandler:
else:
return []
+ async def get_backfill_events(
+        self, room_id: str, event_id_list: List[str], limit: int
+ ) -> List[EventBase]:
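+        """Get a list of events for a given room that happened before and
+        including the events in `event_id_list`. Navigates backwards in the
+        DAG via prev_events (and via MSC2716 insertion/batch events when that
+        experimental feature is enabled), processing the newest-in-time
+        events first.
+
+        Args:
+            room_id: Room where we want to find the events from
+            event_id_list: List of event ids which we want to find the
+                preceding events for
+            limit: Max number of events to return
+
+        Returns:
+            List of events, newest first (descending depth)
+        """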
+ event_id_results = set()
+
+        # In a PriorityQueue, the lowest valued entries are retrieved first.
+        # We're using depth as the priority in the queue, tie-broken by
+        # stream_ordering. Depth is lowest at the oldest-in-time message and
+        # highest at the newest-in-time message. We negate both values when
+        # adding events to the queue so that we process the newest-in-time
+        # messages first, going backwards in time.
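+        #
+        # For example, an event at depth 10 with stream_ordering 5 is queued
+        # as (-10, -5, ...) and is therefore retrieved before an event queued
+        # at depth 9 as (-9, ...).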
+ queue = PriorityQueue()
+
+ seed_events = await self.store.get_events_as_list(event_id_list)
+ for seed_event in seed_events:
+ # Make sure the seed event actually pertains to this room. We also
+ # need to make sure the depth is available since our whole DAG
+ # navigation here depends on depth.
+ if seed_event.room_id == room_id and seed_event.depth:
+ queue.put(
+ (
+ -seed_event.depth,
+ -seed_event.internal_metadata.stream_ordering,
+ seed_event.event_id,
+ seed_event.type,
+ )
+ )
+
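+        # Work through the queue, newest-in-time events first, until we have
+        # collected `limit` events or the queue is exhausted.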
+ while not queue.empty() and len(event_id_results) < limit:
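+            # `get_nowait` raises `Empty` if the queue has been drained; this
+            # shouldn't happen given the `queue.empty()` check above, but we
+            # handle it defensively.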
+ try:
+ _, _, event_id, event_type = queue.get_nowait()
+ except Empty:
+ break
+
+ if event_id in event_id_results:
+ continue
+
+ event_id_results.add(event_id)
+
+ if self.hs.config.experimental.msc2716_enabled:
+                # Try to find any potential historical batches of message
+                # history.
+                #
+                # First we look for insertion events connected to the current
+                # event (by prev_event). If we find any, we add them to the
+                # queue and navigate up the DAG like normal in the next
+                # iteration of the loop.
+ connected_insertion_event_backfill_results = (
+ await self.store.get_connected_insertion_event_backfill_results(
+ event_id, limit - len(event_id_results)
+ )
+ )
+ logger.debug(
+                    "get_backfill_events: connected_insertion_event_backfill_results=%s",
+ connected_insertion_event_backfill_results,
+ )
+ for (
+ connected_insertion_event_backfill_item
+ ) in connected_insertion_event_backfill_results:
+ if (
+ connected_insertion_event_backfill_item.event_id
+ not in event_id_results
+ ):
+ queue.put(
+ (
+ -connected_insertion_event_backfill_item.depth,
+ -connected_insertion_event_backfill_item.stream_ordering,
+ connected_insertion_event_backfill_item.event_id,
+ connected_insertion_event_backfill_item.type,
+ )
+ )
+
+                # Second, we try to find any batch events connected to the
+                # given insertion event (by batch_id). If we find any, we add
+                # them to the queue and navigate up the DAG like normal in
+                # the next iteration of the loop.
+ if event_type == EventTypes.MSC2716_INSERTION:
+ connected_batch_event_backfill_results = (
+ await self.store.get_connected_batch_event_backfill_results(
+ event_id, limit - len(event_id_results)
+ )
+ )
+ logger.debug(
+                        "get_backfill_events: connected_batch_event_backfill_results=%s",
+ connected_batch_event_backfill_results,
+ )
+ for (
+ connected_batch_event_backfill_item
+ ) in connected_batch_event_backfill_results:
+ if (
+ connected_batch_event_backfill_item.event_id
+ not in event_id_results
+ ):
+ queue.put(
+ (
+ -connected_batch_event_backfill_item.depth,
+ -connected_batch_event_backfill_item.stream_ordering,
+ connected_batch_event_backfill_item.event_id,
+ connected_batch_event_backfill_item.type,
+ )
+ )
+
+            # Finally, we navigate up the DAG via prev_events as normal
+ connected_prev_event_backfill_results = (
+ await self.store.get_connected_prev_event_backfill_results(
+ event_id, limit - len(event_id_results)
+ )
+ )
+ logger.debug(
+                "get_backfill_events: connected_prev_event_backfill_results=%s",
+ connected_prev_event_backfill_results,
+ )
+ for (
+ connected_prev_event_backfill_item
+ ) in connected_prev_event_backfill_results:
+ if connected_prev_event_backfill_item.event_id not in event_id_results:
+ queue.put(
+ (
+ -connected_prev_event_backfill_item.depth,
+ -connected_prev_event_backfill_item.stream_ordering,
+ connected_prev_event_backfill_item.event_id,
+ connected_prev_event_backfill_item.type,
+ )
+ )
+
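+        # Fetch the full events for the collected event ids and return them
+        # newest first (descending depth, tie-broken by descending
+        # stream_ordering).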
+ events = await self.store.get_events_as_list(event_id_results)
+ return sorted(
+ events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering)
+ )
+
@log_function
async def on_backfill_request(
self, origin: str, room_id: str, pdu_list: List[str], limit: int
@@ -1053,6 +1183,34 @@ class FederationHandler:
limit = min(limit, 100)
events = await self.store.get_backfill_events(room_id, pdu_list, limit)
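+        # Log the results of both the old and new implementations so they can
+        # be compared while the new implementation is validated.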
+ logger.info(
+ "old implementation backfill events=%s",
+ [
+ "event_id=%s,depth=%d,body=%s,prevs=%s\n"
+ % (
+ event.event_id,
+ event.depth,
+ event.content.get("body", event.type),
+ event.prev_event_ids(),
+ )
+ for event in events
+ ],
+ )
+
+ events = await self.get_backfill_events(room_id, pdu_list, limit)
+ logger.info(
+ "new implementation backfill events=%s",
+ [
+ "event_id=%s,depth=%d,body=%s,prevs=%s\n"
+ % (
+ event.event_id,
+ event.depth,
+ event.content.get("body", event.type),
+ event.prev_event_ids(),
+ )
+ for event in events
+ ],
+ )
events = await filter_events_for_server(self.storage, origin, events)
|