summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11114.bugfix1
-rwxr-xr-xscripts-dev/complement.sh2
-rw-r--r--synapse/api/room_versions.py31
-rw-r--r--synapse/handlers/federation.py27
-rw-r--r--synapse/handlers/federation_event.py34
-rw-r--r--synapse/handlers/message.py23
-rw-r--r--synapse/handlers/room_batch.py12
-rw-r--r--synapse/handlers/room_member.py2
-rw-r--r--synapse/rest/client/room_batch.py8
-rw-r--r--synapse/storage/databases/main/event_federation.py324
-rw-r--r--synapse/storage/databases/main/events.py13
11 files changed, 348 insertions, 129 deletions
diff --git a/changelog.d/11114.bugfix b/changelog.d/11114.bugfix
new file mode 100644
index 0000000000..c6e65df97f
--- /dev/null
+++ b/changelog.d/11114.bugfix
@@ -0,0 +1 @@
+Fix [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) historical messages backfilling in random order on remote homeservers.
diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh
index 7d38b39e90..927c69a753 100755
--- a/scripts-dev/complement.sh
+++ b/scripts-dev/complement.sh
@@ -65,4 +65,4 @@ if [[ -n "$1" ]]; then
 fi
 
 # Run the tests!
-go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403,msc2716 -count=1 "${EXTRA_COMPLEMENT_ARGS[@]}" ./tests/...
+go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403,msc2716 -count=1 "${EXTRA_COMPLEMENT_ARGS[@]}" ./tests
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 0a895bba48..7ef4701b13 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -81,6 +81,8 @@ class RoomVersion:
     msc2716_historical = attr.ib(type=bool)
     # MSC2716: Adds support for redacting "insertion", "chunk", and "marker" events
     msc2716_redactions = attr.ib(type=bool)
+    # MSC2716: Adds support for events with no `prev_events` but with some `auth_events`
+    msc2716_empty_prev_events = attr.ib(type=bool)
 
 
 class RoomVersions:
@@ -99,6 +101,7 @@ class RoomVersions:
         msc2403_knocking=False,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc2716_empty_prev_events=False,
     )
     V2 = RoomVersion(
         "2",
@@ -115,6 +118,7 @@ class RoomVersions:
         msc2403_knocking=False,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc2716_empty_prev_events=False,
     )
     V3 = RoomVersion(
         "3",
@@ -131,6 +135,7 @@ class RoomVersions:
         msc2403_knocking=False,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc2716_empty_prev_events=False,
     )
     V4 = RoomVersion(
         "4",
@@ -147,6 +152,7 @@ class RoomVersions:
         msc2403_knocking=False,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc2716_empty_prev_events=False,
     )
     V5 = RoomVersion(
         "5",
@@ -163,6 +169,7 @@ class RoomVersions:
         msc2403_knocking=False,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc2716_empty_prev_events=False,
     )
     V6 = RoomVersion(
         "6",
@@ -179,6 +186,7 @@ class RoomVersions:
         msc2403_knocking=False,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc2716_empty_prev_events=False,
     )
     MSC2176 = RoomVersion(
         "org.matrix.msc2176",
@@ -195,6 +203,7 @@ class RoomVersions:
         msc2403_knocking=False,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc2716_empty_prev_events=False,
     )
     V7 = RoomVersion(
         "7",
@@ -211,6 +220,7 @@ class RoomVersions:
         msc2403_knocking=True,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc2716_empty_prev_events=False,
     )
     V8 = RoomVersion(
         "8",
@@ -227,6 +237,7 @@ class RoomVersions:
         msc2403_knocking=True,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc2716_empty_prev_events=False,
     )
     V9 = RoomVersion(
         "9",
@@ -243,6 +254,7 @@ class RoomVersions:
         msc2403_knocking=True,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc2716_empty_prev_events=False,
     )
     MSC2716v3 = RoomVersion(
         "org.matrix.msc2716v3",
@@ -259,6 +271,24 @@ class RoomVersions:
         msc2403_knocking=True,
         msc2716_historical=True,
         msc2716_redactions=True,
+        msc2716_empty_prev_events=False,
+    )
+    MSC2716v4 = RoomVersion(
+        "org.matrix.msc2716v4",
+        RoomDisposition.UNSTABLE,
+        EventFormatVersions.V3,
+        StateResolutionVersions.V2,
+        enforce_key_validity=True,
+        special_case_aliases_auth=False,
+        strict_canonicaljson=True,
+        limit_notifications_power_levels=True,
+        msc2176_redaction_rules=False,
+        msc3083_join_rules=False,
+        msc3375_redaction_rules=False,
+        msc2403_knocking=True,
+        msc2716_historical=True,
+        msc2716_redactions=True,
+        msc2716_empty_prev_events=True,
     )
 
 
@@ -276,6 +306,7 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
         RoomVersions.V8,
         RoomVersions.V9,
         RoomVersions.MSC2716v3,
+        RoomVersions.MSC2716v4,
     )
 }
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3112cc88b1..3f5d1d701f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -136,9 +136,14 @@ class FederationHandler:
         oldest_events_with_depth = (
             await self.store.get_oldest_event_ids_with_depth_in_room(room_id)
         )
-        insertion_events_to_be_backfilled = (
-            await self.store.get_insertion_event_backwards_extremities_in_room(room_id)
-        )
+
+        insertion_events_to_be_backfilled: Dict[str, int] = {}
+        if self.hs.config.experimental.msc2716_enabled:
+            insertion_events_to_be_backfilled = (
+                await self.store.get_insertion_event_backward_extremities_in_room(
+                    room_id
+                )
+            )
         logger.debug(
             "_maybe_backfill_inner: extremities oldest_events_with_depth=%s insertion_events_to_be_backfilled=%s",
             oldest_events_with_depth,
@@ -241,11 +246,12 @@ class FederationHandler:
         ]
 
         logger.debug(
-            "room_id: %s, backfill: current_depth: %s, limit: %s, max_depth: %s, extrems: %s filtered_sorted_extremeties_tuple: %s",
+            "room_id: %s, backfill: current_depth: %s, limit: %s, max_depth: %s, extrems (%d): %s filtered_sorted_extremeties_tuple: %s",
             room_id,
             current_depth,
             limit,
             max_depth,
+            len(sorted_extremeties_tuple),
             sorted_extremeties_tuple,
             filtered_sorted_extremeties_tuple,
         )
@@ -1047,6 +1053,19 @@ class FederationHandler:
         limit = min(limit, 100)
 
         events = await self.store.get_backfill_events(room_id, pdu_list, limit)
+        logger.debug(
+            "on_backfill_request: backfill events=%s",
+            [
+                "event_id=%s,depth=%d,body=%s,prevs=%s\n"
+                % (
+                    event.event_id,
+                    event.depth,
+                    event.content.get("body", event.type),
+                    event.prev_event_ids(),
+                )
+                for event in events
+            ],
+        )
 
         events = await filter_events_for_server(self.storage, origin, events)
 
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 1a1cd93b1a..afc1de9894 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -514,7 +514,11 @@ class FederationEventHandler:
                     f"room {ev.room_id}, when we were backfilling in {room_id}"
                 )
 
-        await self._process_pulled_events(dest, events, backfilled=True)
+        await self._process_pulled_events(
+            dest,
+            events,
+            backfilled=True,
+        )
 
     async def _get_missing_events_for_pdu(
         self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
@@ -632,11 +636,24 @@ class FederationEventHandler:
             backfilled: True if this is part of a historical batch of events (inhibits
                 notification to clients, and validation of device keys.)
         """
+        logger.debug(
+            "processing pulled backfilled=%s events=%s",
+            backfilled,
+            [
+                "event_id=%s,depth=%d,body=%s,prevs=%s\n"
+                % (
+                    event.event_id,
+                    event.depth,
+                    event.content.get("body", event.type),
+                    event.prev_event_ids(),
+                )
+                for event in events
+            ],
+        )
 
         # We want to sort these by depth so we process them and
         # tell clients about them in order.
         sorted_events = sorted(events, key=lambda x: x.depth)
-
         for ev in sorted_events:
             with nested_logging_context(ev.event_id):
                 await self._process_pulled_event(origin, ev, backfilled=backfilled)
@@ -998,6 +1015,8 @@ class FederationEventHandler:
 
         await self._run_push_actions_and_persist_event(event, context, backfilled)
 
+        await self._handle_marker_event(origin, event)
+
         if backfilled or context.rejected:
             return
 
@@ -1077,8 +1096,6 @@ class FederationEventHandler:
                     event.sender,
                 )
 
-        await self._handle_marker_event(origin, event)
-
     async def _resync_device(self, sender: str) -> None:
         """We have detected that the device list for the given user may be out
         of sync, so we try and resync them.
@@ -1325,7 +1342,14 @@ class FederationEventHandler:
             return event, context
 
         events_to_persist = (x for x in (prep(event) for event in fetched_events) if x)
-        await self.persist_events_and_notify(room_id, tuple(events_to_persist))
+        await self.persist_events_and_notify(
+            room_id,
+            tuple(events_to_persist),
+            # Mark these events backfilled as they're historic events that will
+            # eventually be backfilled. For example, missing events we fetch
+            # during backfill should be marked as backfilled as well.
+            backfilled=True,
+        )
 
     async def _check_event_auth(
         self,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index d4c2a6ab7a..916ba0662d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -949,14 +949,24 @@ class EventCreationHandler:
         else:
             prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
 
-        # we now ought to have some prev_events (unless it's a create event).
-        #
-        # do a quick sanity check here, rather than waiting until we've created the
+        # Do a quick sanity check here, rather than waiting until we've created the
         # event and then try to auth it (which fails with a somewhat confusing "No
         # create event in auth events")
-        assert (
-            builder.type == EventTypes.Create or len(prev_event_ids) > 0
-        ), "Attempting to create an event with no prev_events"
+        room_version_obj = await self.store.get_room_version(builder.room_id)
+        if room_version_obj.msc2716_empty_prev_events:
+            # We allow events with no `prev_events` but it better have some `auth_events`
+            assert (
+                builder.type == EventTypes.Create
+                or len(prev_event_ids) > 0
+                # Allow an event to have empty list of prev_event_ids
+                # only if it has auth_event_ids.
+                or (auth_event_ids and len(auth_event_ids) > 0)
+            ), "Attempting to create an event with no prev_events or auth_event_ids"
+        else:
+            # we now ought to have some prev_events (unless it's a create event).
+            assert (
+                builder.type == EventTypes.Create or len(prev_event_ids) > 0
+            ), "Attempting to create an event with no prev_events"
 
         event = await builder.build(
             prev_event_ids=prev_event_ids,
@@ -1504,6 +1514,7 @@ class EventCreationHandler:
                 next_batch_id = event.content.get(
                     EventContentFields.MSC2716_NEXT_BATCH_ID
                 )
+
                 conflicting_insertion_event_id = None
                 if next_batch_id:
                     conflicting_insertion_event_id = (
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 0723286383..2e31532389 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -13,10 +13,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-def generate_fake_event_id() -> str:
-    return "$fake_" + random_string(43)
-
-
 class RoomBatchHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
@@ -184,7 +180,7 @@ class RoomBatchHandler:
 
         # Make the state events float off on their own so we don't have a
         # bunch of `@mxid joined the room` noise between each batch
-        prev_event_id_for_state_chain = generate_fake_event_id()
+        prev_event_ids_for_state_chain: List[str] = []
 
         for state_event in state_events_at_start:
             assert_params_in_dict(
@@ -221,7 +217,7 @@ class RoomBatchHandler:
                     action=membership,
                     content=event_dict["content"],
                     outlier=True,
-                    prev_event_ids=[prev_event_id_for_state_chain],
+                    prev_event_ids=prev_event_ids_for_state_chain,
                     # Make sure to use a copy of this list because we modify it
                     # later in the loop here. Otherwise it will be the same
                     # reference and also update in the event when we append later.
@@ -240,7 +236,7 @@ class RoomBatchHandler:
                     ),
                     event_dict,
                     outlier=True,
-                    prev_event_ids=[prev_event_id_for_state_chain],
+                    prev_event_ids=prev_event_ids_for_state_chain,
                     # Make sure to use a copy of this list because we modify it
                     # later in the loop here. Otherwise it will be the same
                     # reference and also update in the event when we append later.
@@ -251,7 +247,7 @@ class RoomBatchHandler:
             state_event_ids_at_start.append(event_id)
             auth_event_ids.append(event_id)
             # Connect all the state in a floating chain
-            prev_event_id_for_state_chain = event_id
+            prev_event_ids_for_state_chain = [event_id]
 
         return state_event_ids_at_start
 
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 08244b690d..343e3c6b7b 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -644,7 +644,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             if block_invite:
                 raise SynapseError(403, "Invites have been disabled on this server")
 
-        if prev_event_ids:
+        if prev_event_ids is not None:
             return await self._local_membership_update(
                 requester=requester,
                 target=target,
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index e4c9451ae0..c9509d2ae3 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -131,6 +131,14 @@ class RoomBatchSendEventRestServlet(RestServlet):
             prev_event_ids_from_query
         )
 
+        if not auth_event_ids:
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "No auth events found for given prev_event query parameter. The prev_event=%s probably does not exist."
+                % prev_event_ids_from_query,
+                errcode=Codes.INVALID_PARAM,
+            )
+
         state_event_ids_at_start = []
         # Create and persist all of the state events that float off on their own
         # before the batch. These will most likely be all of the invite/member
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ef5d1ef01e..c58f7dd009 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -14,11 +14,21 @@
 import itertools
 import logging
 from queue import Empty, PriorityQueue
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    NamedTuple,
+    Optional,
+    Set,
+    Tuple,
+)
 
 from prometheus_client import Counter, Gauge
 
-from synapse.api.constants import MAX_DEPTH
+from synapse.api.constants import MAX_DEPTH, EventTypes
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.events import EventBase, make_event_from_dict
@@ -56,6 +66,14 @@ pdus_pruned_from_federation_queue = Counter(
 logger = logging.getLogger(__name__)
 
 
+# All the info we need while iterating the DAG while backfilling
+class BackfillQueueNavigationItem(NamedTuple):
+    depth: int
+    stream_ordering: int
+    event_id: str
+    type: str
+
+
 class _NoChainCoverIndex(Exception):
     def __init__(self, room_id: str):
         super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
@@ -65,6 +83,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
     def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
+        self.hs = hs
+
         if hs.config.worker.run_background_tasks:
             hs.get_clock().looping_call(
                 self._delete_old_forward_extrem_cache, 60 * 60 * 1000
@@ -728,7 +748,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             room_id,
         )
 
-    async def get_insertion_event_backwards_extremities_in_room(
+    async def get_insertion_event_backward_extremities_in_room(
         self, room_id
     ) -> Dict[str, int]:
         """Get the insertion events we know about that we haven't backfilled yet.
@@ -745,7 +765,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             Map from event_id to depth
         """
 
-        def get_insertion_event_backwards_extremities_in_room_txn(txn, room_id):
+        def get_insertion_event_backward_extremities_in_room_txn(txn, room_id):
             sql = """
                 SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
                 /* We only want insertion events that are also marked as backwards extremities */
@@ -761,8 +781,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             return dict(txn)
 
         return await self.db_pool.runInteraction(
-            "get_insertion_event_backwards_extremities_in_room",
-            get_insertion_event_backwards_extremities_in_room_txn,
+            "get_insertion_event_backward_extremities_in_room",
+            get_insertion_event_backward_extremities_in_room_txn,
             room_id,
         )
 
@@ -988,143 +1008,243 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
         )
 
-    async def get_backfill_events(self, room_id: str, event_list: list, limit: int):
-        """Get a list of Events for a given topic that occurred before (and
-        including) the events in event_list. Return a list of max size `limit`
+    def _get_connected_batch_event_backfill_results_txn(
+        self, txn: LoggingTransaction, insertion_event_id: str, limit: int
+    ) -> List[BackfillQueueNavigationItem]:
+        """
+        Find any batch connections of a given insertion event.
+        A batch event points at a insertion event via:
+        batch_event.content[MSC2716_BATCH_ID] -> insertion_event.content[MSC2716_NEXT_BATCH_ID]
 
         Args:
-            room_id
-            event_list
-            limit
+            txn: The database transaction to use
+            insertion_event_id: The event ID to navigate from. We will find
+                batch events that point back at this insertion event.
+            limit: Max number of event ID's to query for and return
+
+        Returns:
+            List of batch events that the backfill queue can process
+        """
+        batch_connection_query = """
+            SELECT e.depth, e.stream_ordering, c.event_id, e.type FROM insertion_events AS i
+            /* Find the batch that connects to the given insertion event */
+            INNER JOIN batch_events AS c
+            ON i.next_batch_id = c.batch_id
+            /* Get the depth of the batch start event from the events table */
+            INNER JOIN events AS e USING (event_id)
+            /* Find an insertion event which matches the given event_id */
+            WHERE i.event_id = ?
+            LIMIT ?
         """
-        event_ids = await self.db_pool.runInteraction(
-            "get_backfill_events",
-            self._get_backfill_events,
-            room_id,
-            event_list,
-            limit,
-        )
-        events = await self.get_events_as_list(event_ids)
-        return sorted(events, key=lambda e: -e.depth)
 
-    def _get_backfill_events(self, txn, room_id, event_list, limit):
-        logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
+        # Find any batch connections for the given insertion event
+        txn.execute(
+            batch_connection_query,
+            (insertion_event_id, limit),
+        )
+        batch_start_event_id_results = txn.fetchall()
+        return [
+            BackfillQueueNavigationItem(
+                depth=row[0],
+                stream_ordering=row[1],
+                event_id=row[2],
+                type=row[3],
+            )
+            for row in batch_start_event_id_results
+        ]
 
-        event_results = set()
+    def _get_connected_prev_event_backfill_results_txn(
+        self, txn: LoggingTransaction, event_id: str, limit: int
+    ) -> List[BackfillQueueNavigationItem]:
+        """
+        Find any events connected by prev_event the specified event_id.
 
-        # We want to make sure that we do a breadth-first, "depth" ordered
-        # search.
+        Args:
+            txn: The database transaction to use
+            event_id: The event ID to navigate from
+            limit: Max number of event ID's to query for and return
 
+        Returns:
+            List of prev events that the backfill queue can process
+        """
         # Look for the prev_event_id connected to the given event_id
-        query = """
-            SELECT depth, prev_event_id FROM event_edges
-            /* Get the depth of the prev_event_id from the events table */
+        connected_prev_event_query = """
+            SELECT depth, stream_ordering, prev_event_id, events.type FROM event_edges
+            /* Get the depth and stream_ordering of the prev_event_id from the events table */
             INNER JOIN events
             ON prev_event_id = events.event_id
-            /* Find an event which matches the given event_id */
+            /* Look for an edge which matches the given event_id */
             WHERE event_edges.event_id = ?
             AND event_edges.is_state = ?
+            /* Because we can have many events at the same depth,
+            * we want to also tie-break and sort on stream_ordering */
+            ORDER BY depth DESC, stream_ordering DESC
             LIMIT ?
         """
 
-        # Look for the "insertion" events connected to the given event_id
-        connected_insertion_event_query = """
-            SELECT e.depth, i.event_id FROM insertion_event_edges AS i
-            /* Get the depth of the insertion event from the events table */
-            INNER JOIN events AS e USING (event_id)
-            /* Find an insertion event which points via prev_events to the given event_id */
-            WHERE i.insertion_prev_event_id = ?
-            LIMIT ?
+        txn.execute(
+            connected_prev_event_query,
+            (event_id, False, limit),
+        )
+        prev_event_id_results = txn.fetchall()
+        return [
+            BackfillQueueNavigationItem(
+                depth=row[0],
+                stream_ordering=row[1],
+                event_id=row[2],
+                type=row[3],
+            )
+            for row in prev_event_id_results
+        ]
+
+    async def get_backfill_events(
+        self, room_id: str, seed_event_id_list: list, limit: int
+    ):
+        """Get a list of Events for a given topic that occurred before (and
+        including) the events in seed_event_id_list. Return a list of max size `limit`
+
+        Args:
+            room_id
+            seed_event_id_list
+            limit
         """
+        event_ids = await self.db_pool.runInteraction(
+            "get_backfill_events",
+            self._get_backfill_events,
+            room_id,
+            seed_event_id_list,
+            limit,
+        )
+        events = await self.get_events_as_list(event_ids)
+        return sorted(
+            events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering)
+        )
 
-        # Find any batch connections of a given insertion event
-        batch_connection_query = """
-            SELECT e.depth, c.event_id FROM insertion_events AS i
-            /* Find the batch that connects to the given insertion event */
-            INNER JOIN batch_events AS c
-            ON i.next_batch_id = c.batch_id
-            /* Get the depth of the batch start event from the events table */
-            INNER JOIN events AS e USING (event_id)
-            /* Find an insertion event which matches the given event_id */
-            WHERE i.event_id = ?
-            LIMIT ?
+    def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit):
         """
+        We want to make sure that we do a breadth-first, "depth" ordered search.
+        We also handle navigating historical branches of history connected by
+        insertion and batch events.
+        """
+        logger.debug(
+            "_get_backfill_events(room_id=%s): seeding backfill with seed_event_id_list=%s limit=%s",
+            room_id,
+            seed_event_id_list,
+            limit,
+        )
+
+        event_id_results = set()
 
         # In a PriorityQueue, the lowest valued entries are retrieved first.
-        # We're using depth as the priority in the queue.
-        # Depth is lowest at the oldest-in-time message and highest and
-        # newest-in-time message. We add events to the queue with a negative depth so that
-        # we process the newest-in-time messages first going backwards in time.
+        # We're using depth as the priority in the queue and tie-break based on
+        # stream_ordering. Depth is lowest at the oldest-in-time message and
+        # highest and newest-in-time message. We add events to the queue with a
+        # negative depth so that we process the newest-in-time messages first
+        # going backwards in time. stream_ordering follows the same pattern.
         queue = PriorityQueue()
 
-        for event_id in event_list:
-            depth = self.db_pool.simple_select_one_onecol_txn(
+        for seed_event_id in seed_event_id_list:
+            event_lookup_result = self.db_pool.simple_select_one_txn(
                 txn,
                 table="events",
-                keyvalues={"event_id": event_id, "room_id": room_id},
-                retcol="depth",
+                keyvalues={"event_id": seed_event_id, "room_id": room_id},
+                retcols=(
+                    "type",
+                    "depth",
+                    "stream_ordering",
+                ),
                 allow_none=True,
             )
 
-            if depth:
-                queue.put((-depth, event_id))
+            logger.debug(
+                "_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s",
+                room_id,
+                seed_event_id,
+                event_lookup_result["depth"],
+                event_lookup_result["stream_ordering"],
+                event_lookup_result["type"],
+            )
+
+            if event_lookup_result["depth"]:
+                queue.put(
+                    (
+                        -event_lookup_result["depth"],
+                        -event_lookup_result["stream_ordering"],
+                        seed_event_id,
+                        event_lookup_result["type"],
+                    )
+                )
 
-        while not queue.empty() and len(event_results) < limit:
+        while not queue.empty() and len(event_id_results) < limit:
             try:
-                _, event_id = queue.get_nowait()
+                _, _, event_id, event_type = queue.get_nowait()
             except Empty:
                 break
 
-            if event_id in event_results:
+            if event_id in event_id_results:
                 continue
 
-            event_results.add(event_id)
+            event_id_results.add(event_id)
 
             # Try and find any potential historical batches of message history.
-            #
-            # First we look for an insertion event connected to the current
-            # event (by prev_event). If we find any, we need to go and try to
-            # find any batch events connected to the insertion event (by
-            # batch_id). If we find any, we'll add them to the queue and
-            # navigate up the DAG like normal in the next iteration of the loop.
-            txn.execute(
-                connected_insertion_event_query, (event_id, limit - len(event_results))
-            )
-            connected_insertion_event_id_results = txn.fetchall()
-            logger.debug(
-                "_get_backfill_events: connected_insertion_event_query %s",
-                connected_insertion_event_id_results,
-            )
-            for row in connected_insertion_event_id_results:
-                connected_insertion_event_depth = row[0]
-                connected_insertion_event = row[1]
-                queue.put((-connected_insertion_event_depth, connected_insertion_event))
+            if self.hs.config.experimental.msc2716_enabled:
+                # We need to go and try to find any batch events connected
+                # to a given insertion event (by batch_id). If we find any, we'll
+                # add them to the queue and navigate up the DAG like normal in the
+                # next iteration of the loop.
+                if event_type == EventTypes.MSC2716_INSERTION:
+                    # Find any batch connections for the given insertion event
+                    connected_batch_event_backfill_results = (
+                        self._get_connected_batch_event_backfill_results_txn(
+                            txn, event_id, limit - len(event_id_results)
+                        )
+                    )
+                    logger.debug(
+                        "_get_backfill_events(room_id=%s): connected_batch_event_backfill_results=%s",
+                        room_id,
+                        connected_batch_event_backfill_results,
+                    )
+                    for (
+                        connected_batch_event_backfill_item
+                    ) in connected_batch_event_backfill_results:
+                        if (
+                            connected_batch_event_backfill_item.event_id
+                            not in event_id_results
+                        ):
+                            queue.put(
+                                (
+                                    -connected_batch_event_backfill_item.depth,
+                                    -connected_batch_event_backfill_item.stream_ordering,
+                                    connected_batch_event_backfill_item.event_id,
+                                    connected_batch_event_backfill_item.type,
+                                )
+                            )
 
-                # Find any batch connections for the given insertion event
-                txn.execute(
-                    batch_connection_query,
-                    (connected_insertion_event, limit - len(event_results)),
-                )
-                batch_start_event_id_results = txn.fetchall()
-                logger.debug(
-                    "_get_backfill_events: batch_start_event_id_results %s",
-                    batch_start_event_id_results,
+            # Now we just look up the DAG by prev_events as normal
+            connected_prev_event_backfill_results = (
+                self._get_connected_prev_event_backfill_results_txn(
+                    txn, event_id, limit - len(event_id_results)
                 )
-                for row in batch_start_event_id_results:
-                    if row[1] not in event_results:
-                        queue.put((-row[0], row[1]))
-
-            txn.execute(query, (event_id, False, limit - len(event_results)))
-            prev_event_id_results = txn.fetchall()
+            )
             logger.debug(
-                "_get_backfill_events: prev_event_ids %s", prev_event_id_results
+                "_get_backfill_events(room_id=%s): connected_prev_event_backfill_results=%s",
+                room_id,
+                connected_prev_event_backfill_results,
             )
+            for (
+                connected_prev_event_backfill_item
+            ) in connected_prev_event_backfill_results:
+                if connected_prev_event_backfill_item.event_id not in event_id_results:
+                    queue.put(
+                        (
+                            -connected_prev_event_backfill_item.depth,
+                            -connected_prev_event_backfill_item.stream_ordering,
+                            connected_prev_event_backfill_item.event_id,
+                            connected_prev_event_backfill_item.type,
+                        )
+                    )
 
-            for row in prev_event_id_results:
-                if row[1] not in event_results:
-                    queue.put((-row[0], row[1]))
-
-        return event_results
+        return event_id_results
 
     async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
         ids = await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 596275c23c..3790a52a89 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -2144,9 +2144,14 @@ class PersistEventsStore:
             "   SELECT 1 FROM event_backward_extremities"
             "   WHERE event_id = ? AND room_id = ?"
             " )"
+            # 1. Don't add an event as a extremity again if we already persisted it
+            # as a non-outlier.
+            # 2. Don't add an outlier as an extremity if it has no prev_events
             " AND NOT EXISTS ("
-            "   SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
-            "   AND outlier = ?"
+            "   SELECT 1 FROM events"
+            "   LEFT JOIN event_edges edge"
+            "   ON edge.event_id = events.event_id"
+            "   WHERE events.event_id = ? AND events.room_id = ? AND (events.outlier = ? OR edge.event_id IS NULL)"
             " )"
         )
 
@@ -2172,6 +2177,10 @@ class PersistEventsStore:
                 (ev.event_id, ev.room_id)
                 for ev in events
                 if not ev.internal_metadata.is_outlier()
+                # If we encountered an event with no prev_events, then we might
+                # as well remove it now because it won't ever have anything else
+                # to backfill from.
+                or len(ev.prev_event_ids()) == 0
             ],
         )