summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13635.feature1
-rw-r--r--synapse/handlers/federation.py4
-rw-r--r--synapse/storage/databases/main/event_federation.py188
-rw-r--r--tests/storage/test_event_federation.py481
4 files changed, 626 insertions, 48 deletions
diff --git a/changelog.d/13635.feature b/changelog.d/13635.feature
new file mode 100644
index 0000000000..d86bf7ed80
--- /dev/null
+++ b/changelog.d/13635.feature
@@ -0,0 +1 @@
+Exponentially backoff from backfilling the same event over and over.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 583d5ecd77..e1a4265a64 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -226,9 +226,7 @@ class FederationHandler:
         """
         backwards_extremities = [
             _BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY)
-            for event_id, depth in await self.store.get_oldest_event_ids_with_depth_in_room(
-                room_id
-            )
+            for event_id, depth in await self.store.get_backfill_points_in_room(room_id)
         ]
 
         insertion_events_to_be_backfilled: List[_BackfillPoint] = []
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ef477978ed..3251fca6fb 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import datetime
 import itertools
 import logging
 from queue import Empty, PriorityQueue
@@ -43,7 +44,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
-from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -72,6 +73,13 @@ pdus_pruned_from_federation_queue = Counter(
 
 logger = logging.getLogger(__name__)
 
+BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS: int = int(
+    datetime.timedelta(days=7).total_seconds()
+)
+BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS: int = int(
+    datetime.timedelta(hours=1).total_seconds()
+)
+
 
 # All the info we need while iterating the DAG while backfilling
 @attr.s(frozen=True, slots=True, auto_attribs=True)
@@ -715,96 +723,189 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
     @trace
     @tag_args
-    async def get_oldest_event_ids_with_depth_in_room(
-        self, room_id: str
+    async def get_backfill_points_in_room(
+        self,
+        room_id: str,
     ) -> List[Tuple[str, int]]:
-        """Gets the oldest events(backwards extremities) in the room along with the
-        aproximate depth.
-
-        We use this function so that we can compare and see if someones current
-        depth at their current scrollback is within pagination range of the
-        event extremeties. If the current depth is close to the depth of given
-        oldest event, we can trigger a backfill.
+        """
+        Gets the oldest events(backwards extremities) in the room along with the
+        approximate depth. Sorted by depth, highest to lowest (descending).
 
         Args:
             room_id: Room where we want to find the oldest events
 
         Returns:
-            List of (event_id, depth) tuples
+            List of (event_id, depth) tuples. Sorted by depth, highest to lowest
+            (descending)
         """
 
-        def get_oldest_event_ids_with_depth_in_room_txn(
+        def get_backfill_points_in_room_txn(
             txn: LoggingTransaction, room_id: str
         ) -> List[Tuple[str, int]]:
-            # Assemble a dictionary with event_id -> depth for the oldest events
+            # Assemble a tuple lookup of event_id -> depth for the oldest events
             # we know of in the room. Backwards extremeties are the oldest
             # events we know of in the room but we only know of them because
-            # some other event referenced them by prev_event and aren't peristed
-            # in our database yet (meaning we don't know their depth
-            # specifically). So we need to look for the aproximate depth from
+            # some other event referenced them by prev_event and aren't
+            # persisted in our database yet (meaning we don't know their depth
+            # specifically). So we need to look for the approximate depth from
             # the events connected to the current backwards extremeties.
             sql = """
-                SELECT b.event_id, MAX(e.depth) FROM events as e
+                SELECT backward_extrem.event_id, event.depth FROM events AS event
                 /**
                  * Get the edge connections from the event_edges table
                  * so we can see whether this event's prev_events points
                  * to a backward extremity in the next join.
                  */
-                INNER JOIN event_edges as g
-                ON g.event_id = e.event_id
+                INNER JOIN event_edges AS edge
+                ON edge.event_id = event.event_id
                 /**
                  * We find the "oldest" events in the room by looking for
                  * events connected to backwards extremeties (oldest events
                  * in the room that we know of so far).
                  */
-                INNER JOIN event_backward_extremities as b
-                ON g.prev_event_id = b.event_id
-                WHERE b.room_id = ? AND g.is_state is ?
-                GROUP BY b.event_id
+                INNER JOIN event_backward_extremities AS backward_extrem
+                ON edge.prev_event_id = backward_extrem.event_id
+                /**
+                 * We use this info to make sure we don't retry to use a backfill point
+                 * if we've already attempted to backfill from it recently.
+                 */
+                LEFT JOIN event_failed_pull_attempts AS failed_backfill_attempt_info
+                ON
+                    failed_backfill_attempt_info.room_id = backward_extrem.room_id
+                    AND failed_backfill_attempt_info.event_id = backward_extrem.event_id
+                WHERE
+                    backward_extrem.room_id = ?
+                    /* We only care about non-state edges because we used to use
+                     * `event_edges` for two different sorts of "edges" (the current
+                     * event DAG, but also a link to the previous state, for state
+                     * events). These legacy state event edges can be distinguished by
+                     * `is_state` and are removed from the codebase and schema but
+                     * because the schema change is in a background update, it's not
+                     * necessarily safe to assume that it will have been completed.
+                     */
+                    AND edge.is_state is ? /* False */
+                    /**
+                     * Exponential back-off (up to the upper bound) so we don't retry the
+                     * same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
+                     *
+                     * We use `1 << n` as a power of 2 equivalent for compatibility
+                     * with older SQLites. The left shift equivalent only works with
+                     * powers of 2 because left shift is a binary operation (base-2).
+                     * Otherwise, we would use `power(2, n)` or the power operator, `2^n`.
+                     */
+                    AND (
+                        failed_backfill_attempt_info.event_id IS NULL
+                        OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */)
+                    )
+                /**
+                 * Sort from highest to the lowest depth. Then tie-break on
+                 * alphabetical order of the event_ids so we get a consistent
+                 * ordering which is nice when asserting things in tests.
+                 */
+                ORDER BY event.depth DESC, backward_extrem.event_id DESC
             """
 
-            txn.execute(sql, (room_id, False))
+            if isinstance(self.database_engine, PostgresEngine):
+                least_function = "least"
+            elif isinstance(self.database_engine, Sqlite3Engine):
+                least_function = "min"
+            else:
+                raise RuntimeError("Unknown database engine")
+
+            txn.execute(
+                sql % (least_function,),
+                (
+                    room_id,
+                    False,
+                    self._clock.time_msec(),
+                    1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS,
+                    1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS,
+                ),
+            )
 
             return cast(List[Tuple[str, int]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
-            "get_oldest_event_ids_with_depth_in_room",
-            get_oldest_event_ids_with_depth_in_room_txn,
+            "get_backfill_points_in_room",
+            get_backfill_points_in_room_txn,
             room_id,
         )
 
     @trace
     async def get_insertion_event_backward_extremities_in_room(
-        self, room_id: str
+        self,
+        room_id: str,
     ) -> List[Tuple[str, int]]:
-        """Get the insertion events we know about that we haven't backfilled yet.
-
-        We use this function so that we can compare and see if someones current
-        depth at their current scrollback is within pagination range of the
-        insertion event. If the current depth is close to the depth of given
-        insertion event, we can trigger a backfill.
+        """
+        Get the insertion events we know about that we haven't backfilled yet
+        along with the approximate depth. Sorted by depth, highest to lowest
+        (descending).
 
         Args:
             room_id: Room where we want to find the oldest events
 
         Returns:
-            List of (event_id, depth) tuples
+            List of (event_id, depth) tuples. Sorted by depth, highest to lowest
+            (descending)
         """
 
         def get_insertion_event_backward_extremities_in_room_txn(
             txn: LoggingTransaction, room_id: str
         ) -> List[Tuple[str, int]]:
             sql = """
-                SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
+                SELECT
+                    insertion_event_extremity.event_id, event.depth
                 /* We only want insertion events that are also marked as backwards extremities */
-                INNER JOIN insertion_event_extremities as b USING (event_id)
+                FROM insertion_event_extremities AS insertion_event_extremity
                 /* Get the depth of the insertion event from the events table */
-                INNER JOIN events AS e USING (event_id)
-                WHERE b.room_id = ?
-                GROUP BY b.event_id
+                INNER JOIN events AS event USING (event_id)
+                /**
+                 * We use this info to make sure we don't retry to use a backfill point
+                 * if we've already attempted to backfill from it recently.
+                 */
+                LEFT JOIN event_failed_pull_attempts AS failed_backfill_attempt_info
+                ON
+                    failed_backfill_attempt_info.room_id = insertion_event_extremity.room_id
+                    AND failed_backfill_attempt_info.event_id = insertion_event_extremity.event_id
+                WHERE
+                    insertion_event_extremity.room_id = ?
+                    /**
+                     * Exponential back-off (up to the upper bound) so we don't retry the
+                     * same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc
+                     *
+                     * We use `1 << n` as a power of 2 equivalent for compatibility
+                     * with older SQLites. The left shift equivalent only works with
+                     * powers of 2 because left shift is a binary operation (base-2).
+                     * Otherwise, we would use `power(2, n)` or the power operator, `2^n`.
+                     */
+                    AND (
+                        failed_backfill_attempt_info.event_id IS NULL
+                        OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */)
+                    )
+                /**
+                 * Sort from highest to the lowest depth. Then tie-break on
+                 * alphabetical order of the event_ids so we get a consistent
+                 * ordering which is nice when asserting things in tests.
+                 */
+                ORDER BY event.depth DESC, insertion_event_extremity.event_id DESC
             """
 
-            txn.execute(sql, (room_id,))
+            if isinstance(self.database_engine, PostgresEngine):
+                least_function = "least"
+            elif isinstance(self.database_engine, Sqlite3Engine):
+                least_function = "min"
+            else:
+                raise RuntimeError("Unknown database engine")
+
+            txn.execute(
+                sql % (least_function,),
+                (
+                    room_id,
+                    self._clock.time_msec(),
+                    1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS,
+                    1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS,
+                ),
+            )
             return cast(List[Tuple[str, int]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
@@ -1539,7 +1640,12 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         self,
         room_id: str,
     ) -> Optional[Tuple[str, str]]:
-        """Get the next event ID in the staging area for the given room."""
+        """
+        Get the next event ID in the staging area for the given room.
+
+        Returns:
+            Tuple of the `origin` and `event_id`
+        """
 
         def _get_next_staged_event_id_for_room_txn(
             txn: LoggingTransaction,
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index a6679e1312..85739c464e 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -12,25 +12,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple, Union
+import datetime
+from typing import Dict, List, Tuple, Union
 
 import attr
 from parameterized import parameterized
 
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import EventTypes
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
     EventFormatVersions,
     RoomVersion,
 )
 from synapse.events import _EventInternalMetadata
-from synapse.util import json_encoder
+from synapse.server import HomeServer
+from synapse.storage.database import LoggingTransaction
+from synapse.types import JsonDict
+from synapse.util import Clock, json_encoder
 
 import tests.unittest
 import tests.utils
 
 
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class _BackfillSetupInfo:
+    room_id: str
+    depth_map: Dict[str, int]
+
+
 class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
 
     def test_get_prev_events_for_room(self):
@@ -571,11 +584,471 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         )
         self.assertEqual(count, 1)
 
-        _, event_id = self.get_success(
+        next_staged_event_info = self.get_success(
             self.store.get_next_staged_event_id_for_room(room_id)
         )
+        assert next_staged_event_info
+        _, event_id = next_staged_event_info
         self.assertEqual(event_id, "$fake_event_id_500")
 
+    def _setup_room_for_backfill_tests(self) -> _BackfillSetupInfo:
+        """
+        Sets up a room with various events and backward extremities to test
+        backfill functions against.
+
+        Returns:
+            _BackfillSetupInfo including the `room_id` to test against and
+            `depth_map` of events in the room
+        """
+        room_id = "!backfill-room-test:some-host"
+
+        # The silly graph we use to test grabbing backward extremities,
+        # where the top is the oldest events.
+        #    1 (oldest)
+        #    |
+        #    2 ⹁
+        #    |  \
+        #    |   [b1, b2, b3]
+        #    |   |
+        #    |   A
+        #    |  /
+        #    3 {
+        #    |  \
+        #    |   [b4, b5, b6]
+        #    |   |
+        #    |   B
+        #    |  /
+        #    4 ´
+        #    |
+        #    5 (newest)
+
+        event_graph: Dict[str, List[str]] = {
+            "1": [],
+            "2": ["1"],
+            "3": ["2", "A"],
+            "4": ["3", "B"],
+            "5": ["4"],
+            "A": ["b1", "b2", "b3"],
+            "b1": ["2"],
+            "b2": ["2"],
+            "b3": ["2"],
+            "B": ["b4", "b5", "b6"],
+            "b4": ["3"],
+            "b5": ["3"],
+            "b6": ["3"],
+        }
+
+        depth_map: Dict[str, int] = {
+            "1": 1,
+            "2": 2,
+            "b1": 3,
+            "b2": 3,
+            "b3": 3,
+            "A": 4,
+            "3": 5,
+            "b4": 6,
+            "b5": 6,
+            "b6": 6,
+            "B": 7,
+            "4": 8,
+            "5": 9,
+        }
+
+        # The events we have persisted on our server.
+        # The rest are events in the room but not backfilled tet.
+        our_server_events = {"5", "4", "B", "3", "A"}
+
+        complete_event_dict_map: Dict[str, JsonDict] = {}
+        stream_ordering = 0
+        for (event_id, prev_event_ids) in event_graph.items():
+            depth = depth_map[event_id]
+
+            complete_event_dict_map[event_id] = {
+                "event_id": event_id,
+                "type": "test_regular_type",
+                "room_id": room_id,
+                "sender": "@sender",
+                "prev_event_ids": prev_event_ids,
+                "auth_event_ids": [],
+                "origin_server_ts": stream_ordering,
+                "depth": depth,
+                "stream_ordering": stream_ordering,
+                "content": {"body": "event" + event_id},
+            }
+
+            stream_ordering += 1
+
+        def populate_db(txn: LoggingTransaction):
+            # Insert the room to satisfy the foreign key constraint of
+            # `event_failed_pull_attempts`
+            self.store.db_pool.simple_insert_txn(
+                txn,
+                "rooms",
+                {
+                    "room_id": room_id,
+                    "creator": "room_creator_user_id",
+                    "is_public": True,
+                    "room_version": "6",
+                },
+            )
+
+            # Insert our server events
+            for event_id in our_server_events:
+                event_dict = complete_event_dict_map[event_id]
+
+                self.store.db_pool.simple_insert_txn(
+                    txn,
+                    table="events",
+                    values={
+                        "event_id": event_dict.get("event_id"),
+                        "type": event_dict.get("type"),
+                        "room_id": event_dict.get("room_id"),
+                        "depth": event_dict.get("depth"),
+                        "topological_ordering": event_dict.get("depth"),
+                        "stream_ordering": event_dict.get("stream_ordering"),
+                        "processed": True,
+                        "outlier": False,
+                    },
+                )
+
+            # Insert the event edges
+            for event_id in our_server_events:
+                for prev_event_id in event_graph[event_id]:
+                    self.store.db_pool.simple_insert_txn(
+                        txn,
+                        table="event_edges",
+                        values={
+                            "event_id": event_id,
+                            "prev_event_id": prev_event_id,
+                            "room_id": room_id,
+                        },
+                    )
+
+            # Insert the backward extremities
+            prev_events_of_our_events = {
+                prev_event_id
+                for our_server_event in our_server_events
+                for prev_event_id in complete_event_dict_map[our_server_event][
+                    "prev_event_ids"
+                ]
+            }
+            backward_extremities = prev_events_of_our_events - our_server_events
+            for backward_extremity in backward_extremities:
+                self.store.db_pool.simple_insert_txn(
+                    txn,
+                    table="event_backward_extremities",
+                    values={
+                        "event_id": backward_extremity,
+                        "room_id": room_id,
+                    },
+                )
+
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "_setup_room_for_backfill_tests_populate_db",
+                populate_db,
+            )
+        )
+
+        return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
+
+    def test_get_backfill_points_in_room(self):
+        """
+        Test to make sure we get some backfill points
+        """
+        setup_info = self._setup_room_for_backfill_tests()
+        room_id = setup_info.room_id
+
+        backfill_points = self.get_success(
+            self.store.get_backfill_points_in_room(room_id)
+        )
+        backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+        self.assertListEqual(
+            backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"]
+        )
+
+    def test_get_backfill_points_in_room_excludes_events_we_have_attempted(
+        self,
+    ):
+        """
+        Test to make sure that events we have attempted to backfill (and within
+        backoff timeout duration) do not show up as an event to backfill again.
+        """
+        setup_info = self._setup_room_for_backfill_tests()
+        room_id = setup_info.room_id
+
+        # Record some attempts to backfill these events which will make
+        # `get_backfill_points_in_room` exclude them because we
+        # haven't passed the backoff interval.
+        self.get_success(
+            self.store.record_event_failed_pull_attempt(room_id, "b5", "fake cause")
+        )
+        self.get_success(
+            self.store.record_event_failed_pull_attempt(room_id, "b4", "fake cause")
+        )
+        self.get_success(
+            self.store.record_event_failed_pull_attempt(room_id, "b3", "fake cause")
+        )
+        self.get_success(
+            self.store.record_event_failed_pull_attempt(room_id, "b2", "fake cause")
+        )
+
+        # No time has passed since we attempted to backfill ^
+
+        backfill_points = self.get_success(
+            self.store.get_backfill_points_in_room(room_id)
+        )
+        backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+        # Only the backfill points that we didn't record earlier exist here.
+        self.assertListEqual(backfill_event_ids, ["b6", "2", "b1"])
+
+    def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration(
+        self,
+    ):
+        """
+        Test to make sure after we fake attempt to backfill event "b3" many times,
+        we can see retry and see the "b3" again after the backoff timeout duration
+        has exceeded.
+        """
+        setup_info = self._setup_room_for_backfill_tests()
+        room_id = setup_info.room_id
+
+        # Record some attempts to backfill these events which will make
+        # `get_backfill_points_in_room` exclude them because we
+        # haven't passed the backoff interval.
+        self.get_success(
+            self.store.record_event_failed_pull_attempt(room_id, "b3", "fake cause")
+        )
+        self.get_success(
+            self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause")
+        )
+        self.get_success(
+            self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause")
+        )
+        self.get_success(
+            self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause")
+        )
+        self.get_success(
+            self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause")
+        )
+
+        # Now advance time by 2 hours and we should only be able to see "b3"
+        # because we have waited long enough for the single attempt (2^1 hours)
+        # but we still shouldn't see "b1" because we haven't waited long enough
+        # for this many attempts. We didn't do anything to "b2" so it should be
+        # visible regardless.
+        self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
+
+        # Make sure that "b1" is not in the list because we've
+        # already attempted many times
+        backfill_points = self.get_success(
+            self.store.get_backfill_points_in_room(room_id)
+        )
+        backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+        self.assertListEqual(backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2"])
+
+        # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and
+        # see if we can now backfill it
+        self.reactor.advance(datetime.timedelta(hours=20).total_seconds())
+
+        # Try again after we advanced enough time and we should see "b3" again
+        backfill_points = self.get_success(
+            self.store.get_backfill_points_in_room(room_id)
+        )
+        backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+        self.assertListEqual(
+            backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"]
+        )
+
+    def _setup_room_for_insertion_backfill_tests(self) -> _BackfillSetupInfo:
+        """
+        Sets up a room with various insertion event backward extremities to test
+        backfill functions against.
+
+        Returns:
+            _BackfillSetupInfo including the `room_id` to test against and
+            `depth_map` of events in the room
+        """
+        room_id = "!backfill-room-test:some-host"
+
+        depth_map: Dict[str, int] = {
+            "1": 1,
+            "2": 2,
+            "insertion_eventA": 3,
+            "3": 4,
+            "insertion_eventB": 5,
+            "4": 6,
+            "5": 7,
+        }
+
+        def populate_db(txn: LoggingTransaction):
+            # Insert the room to satisfy the foreign key constraint of
+            # `event_failed_pull_attempts`
+            self.store.db_pool.simple_insert_txn(
+                txn,
+                "rooms",
+                {
+                    "room_id": room_id,
+                    "creator": "room_creator_user_id",
+                    "is_public": True,
+                    "room_version": "6",
+                },
+            )
+
+            # Insert our server events
+            stream_ordering = 0
+            for event_id, depth in depth_map.items():
+                self.store.db_pool.simple_insert_txn(
+                    txn,
+                    table="events",
+                    values={
+                        "event_id": event_id,
+                        "type": EventTypes.MSC2716_INSERTION
+                        if event_id.startswith("insertion_event")
+                        else "test_regular_type",
+                        "room_id": room_id,
+                        "depth": depth,
+                        "topological_ordering": depth,
+                        "stream_ordering": stream_ordering,
+                        "processed": True,
+                        "outlier": False,
+                    },
+                )
+
+                if event_id.startswith("insertion_event"):
+                    self.store.db_pool.simple_insert_txn(
+                        txn,
+                        table="insertion_event_extremities",
+                        values={
+                            "event_id": event_id,
+                            "room_id": room_id,
+                        },
+                    )
+
+                stream_ordering += 1
+
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "_setup_room_for_insertion_backfill_tests_populate_db",
+                populate_db,
+            )
+        )
+
+        return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
+
+    def test_get_insertion_event_backward_extremities_in_room(self):
+        """
+        Test to make sure insertion event backward extremities are returned.
+        """
+        setup_info = self._setup_room_for_insertion_backfill_tests()
+        room_id = setup_info.room_id
+
+        backfill_points = self.get_success(
+            self.store.get_insertion_event_backward_extremities_in_room(room_id)
+        )
+        backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+        self.assertListEqual(
+            backfill_event_ids, ["insertion_eventB", "insertion_eventA"]
+        )
+
+    def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted(
+        self,
+    ):
+        """
+        Test to make sure that insertion events we have attempted to backfill
+        (and within backoff timeout duration) do not show up as an event to
+        backfill again.
+        """
+        setup_info = self._setup_room_for_insertion_backfill_tests()
+        room_id = setup_info.room_id
+
+        # Record some attempts to backfill these events which will make
+        # `get_insertion_event_backward_extremities_in_room` exclude them
+        # because we haven't passed the backoff interval.
+        self.get_success(
+            self.store.record_event_failed_pull_attempt(
+                room_id, "insertion_eventA", "fake cause"
+            )
+        )
+
+        # No time has passed since we attempted to backfill ^
+
+        backfill_points = self.get_success(
+            self.store.get_insertion_event_backward_extremities_in_room(room_id)
+        )
+        backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+        # Only the backfill points that we didn't record earlier exist here.
+        self.assertListEqual(backfill_event_ids, ["insertion_eventB"])
+
+    def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration(
+        self,
+    ):
+        """
+        Test to make sure after we fake attempt to backfill event
+        "insertion_eventA" many times, we can see retry and see the
+        "insertion_eventA" again after the backoff timeout duration has
+        exceeded.
+        """
+        setup_info = self._setup_room_for_insertion_backfill_tests()
+        room_id = setup_info.room_id
+
+        # Record some attempts to backfill these events which will make
+        # `get_backfill_points_in_room` exclude them because we
+        # haven't passed the backoff interval.
+        self.get_success(
+            self.store.record_event_failed_pull_attempt(
+                room_id, "insertion_eventB", "fake cause"
+            )
+        )
+        self.get_success(
+            self.store.record_event_failed_pull_attempt(
+                room_id, "insertion_eventA", "fake cause"
+            )
+        )
+        self.get_success(
+            self.store.record_event_failed_pull_attempt(
+                room_id, "insertion_eventA", "fake cause"
+            )
+        )
+        self.get_success(
+            self.store.record_event_failed_pull_attempt(
+                room_id, "insertion_eventA", "fake cause"
+            )
+        )
+        self.get_success(
+            self.store.record_event_failed_pull_attempt(
+                room_id, "insertion_eventA", "fake cause"
+            )
+        )
+
+        # Now advance time by 2 hours and we should only be able to see
+        # "insertion_eventB" because we have waited long enough for the single
+        # attempt (2^1 hours) but we still shouldn't see "insertion_eventA"
+        # because we haven't waited long enough for this many attempts.
+        self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
+
+        # Make sure that "insertion_eventA" is not in the list because we've
+        # already attempted many times
+        backfill_points = self.get_success(
+            self.store.get_insertion_event_backward_extremities_in_room(room_id)
+        )
+        backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+        self.assertListEqual(backfill_event_ids, ["insertion_eventB"])
+
+        # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and
+        # see if we can now backfill it
+        self.reactor.advance(datetime.timedelta(hours=20).total_seconds())
+
+        # Try at "insertion_eventA" again after we advanced enough time and we
+        # should see "insertion_eventA" again
+        backfill_points = self.get_success(
+            self.store.get_insertion_event_backward_extremities_in_room(room_id)
+        )
+        backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+        self.assertListEqual(
+            backfill_event_ids, ["insertion_eventB", "insertion_eventA"]
+        )
+
 
 @attr.s
 class FakeEvent: