summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/12522.bugfix1
-rw-r--r--synapse/handlers/federation.py234
-rw-r--r--synapse/handlers/room_batch.py2
-rw-r--r--synapse/storage/databases/main/event_federation.py30
-rw-r--r--synapse/visibility.py7
5 files changed, 168 insertions, 106 deletions
diff --git a/changelog.d/12522.bugfix b/changelog.d/12522.bugfix
new file mode 100644
index 0000000000..2220f05ceb
--- /dev/null
+++ b/changelog.d/12522.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse 0.99.3 which could cause Synapse to consume large amounts of RAM when back-paginating in a large room.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 1434e99056..d2ba70a814 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1,4 +1,4 @@
-# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
 # Copyright 2020 Sorunome
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,10 +15,14 @@
 
 """Contains handlers for federation events."""
 
+import enum
+import itertools
 import logging
+from enum import Enum
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
 
+import attr
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json
 from unpaddedbase64 import decode_base64
@@ -92,6 +96,24 @@ def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
     return sorted(joined_domains.items(), key=lambda d: d[1])
 
 
+class _BackfillPointType(Enum):
+    # a regular backwards extremity (ie, an event which we don't yet have, but which
+    # is referred to by other events in the DAG)
+    BACKWARDS_EXTREMITY = enum.auto()
+
+    # an MSC2716 "insertion event"
+    INSERTION_PONT = enum.auto()
+
+
+@attr.s(slots=True, auto_attribs=True, frozen=True)
+class _BackfillPoint:
+    """A potential point we might backfill from"""
+
+    event_id: str
+    depth: int
+    type: _BackfillPointType
+
+
 class FederationHandler:
     """Handles general incoming federation requests
 
@@ -157,89 +179,51 @@ class FederationHandler:
     async def _maybe_backfill_inner(
         self, room_id: str, current_depth: int, limit: int
     ) -> bool:
-        oldest_events_with_depth = (
-            await self.store.get_oldest_event_ids_with_depth_in_room(room_id)
-        )
+        backwards_extremities = [
+            _BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY)
+            for event_id, depth in await self.store.get_oldest_event_ids_with_depth_in_room(
+                room_id
+            )
+        ]
 
-        insertion_events_to_be_backfilled: Dict[str, int] = {}
+        insertion_events_to_be_backfilled: List[_BackfillPoint] = []
         if self.hs.config.experimental.msc2716_enabled:
-            insertion_events_to_be_backfilled = (
-                await self.store.get_insertion_event_backward_extremities_in_room(
+            insertion_events_to_be_backfilled = [
+                _BackfillPoint(event_id, depth, _BackfillPointType.INSERTION_PONT)
+                for event_id, depth in await self.store.get_insertion_event_backward_extremities_in_room(
                     room_id
                 )
-            )
+            ]
         logger.debug(
-            "_maybe_backfill_inner: extremities oldest_events_with_depth=%s insertion_events_to_be_backfilled=%s",
-            oldest_events_with_depth,
+            "_maybe_backfill_inner: backwards_extremities=%s insertion_events_to_be_backfilled=%s",
+            backwards_extremities,
             insertion_events_to_be_backfilled,
         )
 
-        if not oldest_events_with_depth and not insertion_events_to_be_backfilled:
+        if not backwards_extremities and not insertion_events_to_be_backfilled:
             logger.debug("Not backfilling as no extremeties found.")
             return False
 
-        # We only want to paginate if we can actually see the events we'll get,
-        # as otherwise we'll just spend a lot of resources to get redacted
-        # events.
-        #
-        # We do this by filtering all the backwards extremities and seeing if
-        # any remain. Given we don't have the extremity events themselves, we
-        # need to actually check the events that reference them.
-        #
-        # *Note*: the spec wants us to keep backfilling until we reach the start
-        # of the room in case we are allowed to see some of the history. However
-        # in practice that causes more issues than its worth, as a) its
-        # relatively rare for there to be any visible history and b) even when
-        # there is its often sufficiently long ago that clients would stop
-        # attempting to paginate before backfill reached the visible history.
-        #
-        # TODO: If we do do a backfill then we should filter the backwards
-        #   extremities to only include those that point to visible portions of
-        #   history.
-        #
-        # TODO: Correctly handle the case where we are allowed to see the
-        #   forward event but not the backward extremity, e.g. in the case of
-        #   initial join of the server where we are allowed to see the join
-        #   event but not anything before it. This would require looking at the
-        #   state *before* the event, ignoring the special casing certain event
-        #   types have.
-
-        forward_event_ids = await self.store.get_successor_events(
-            list(oldest_events_with_depth)
-        )
-
-        extremities_events = await self.store.get_events(
-            forward_event_ids,
-            redact_behaviour=EventRedactBehaviour.AS_IS,
-            get_prev_content=False,
+        # we now have a list of potential places to backpaginate from. We prefer to
+        # start with the most recent (ie, max depth), so let's sort the list.
+        sorted_backfill_points: List[_BackfillPoint] = sorted(
+            itertools.chain(
+                backwards_extremities,
+                insertion_events_to_be_backfilled,
+            ),
+            key=lambda e: -int(e.depth),
         )
 
-        # We set `check_history_visibility_only` as we might otherwise get false
-        # positives from users having been erased.
-        filtered_extremities = await filter_events_for_server(
-            self.storage,
-            self.server_name,
-            list(extremities_events.values()),
-            redact=False,
-            check_history_visibility_only=True,
-        )
         logger.debug(
-            "_maybe_backfill_inner: filtered_extremities %s", filtered_extremities
+            "_maybe_backfill_inner: room_id: %s: current_depth: %s, limit: %s, "
+            "backfill points (%d): %s",
+            room_id,
+            current_depth,
+            limit,
+            len(sorted_backfill_points),
+            sorted_backfill_points,
         )
 
-        if not filtered_extremities and not insertion_events_to_be_backfilled:
-            return False
-
-        extremities = {
-            **oldest_events_with_depth,
-            # TODO: insertion_events_to_be_backfilled is currently skipping the filtered_extremities checks
-            **insertion_events_to_be_backfilled,
-        }
-
-        # Check if we reached a point where we should start backfilling.
-        sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1]))
-        max_depth = sorted_extremeties_tuple[0][1]
-
         # If we're approaching an extremity we trigger a backfill, otherwise we
         # no-op.
         #
@@ -249,6 +233,11 @@ class FederationHandler:
         # chose more than one times the limit in case of failure, but choosing a
         # much larger factor will result in triggering a backfill request much
         # earlier than necessary.
+        #
+        # XXX: shouldn't we do this *after* the filter by depth below? Again, we don't
+        # care about events that have happened after our current position.
+        #
+        max_depth = sorted_backfill_points[0].depth
         if current_depth - 2 * limit > max_depth:
             logger.debug(
                 "Not backfilling as we don't need to. %d < %d - 2 * %d",
@@ -265,31 +254,98 @@ class FederationHandler:
         #    2. we have likely previously tried and failed to backfill from that
         #       extremity, so to avoid getting "stuck" requesting the same
         #       backfill repeatedly we drop those extremities.
-        filtered_sorted_extremeties_tuple = [
-            t for t in sorted_extremeties_tuple if int(t[1]) <= current_depth
-        ]
-
-        logger.debug(
-            "room_id: %s, backfill: current_depth: %s, limit: %s, max_depth: %s, extrems (%d): %s filtered_sorted_extremeties_tuple: %s",
-            room_id,
-            current_depth,
-            limit,
-            max_depth,
-            len(sorted_extremeties_tuple),
-            sorted_extremeties_tuple,
-            filtered_sorted_extremeties_tuple,
-        )
-
+        #
         # However, we need to check that the filtered extremities are non-empty.
         # If they are empty then either we can a) bail or b) still attempt to
         # backfill. We opt to try backfilling anyway just in case we do get
         # relevant events.
-        if filtered_sorted_extremeties_tuple:
-            sorted_extremeties_tuple = filtered_sorted_extremeties_tuple
+        #
+        filtered_sorted_backfill_points = [
+            t for t in sorted_backfill_points if t.depth <= current_depth
+        ]
+        if filtered_sorted_backfill_points:
+            logger.debug(
+                "_maybe_backfill_inner: backfill points before current depth: %s",
+                filtered_sorted_backfill_points,
+            )
+            sorted_backfill_points = filtered_sorted_backfill_points
+        else:
+            logger.debug(
+                "_maybe_backfill_inner: all backfill points are *after* current depth. Backfilling anyway."
+            )
+
+        # For performance's sake, we only want to paginate from a particular extremity
+        # if we can actually see the events we'll get. Otherwise, we'd just spend a lot
+        # of resources to get redacted events. We check each extremity in turn and
+        # ignore those which users on our server wouldn't be able to see.
+        #
+        # Additionally, we limit ourselves to backfilling from at most 5 extremities,
+        # for two reasons:
+        #
+        # - The check which determines if we can see an extremity's events can be
+        #   expensive (we load the full state for the room at each of the backfill
+        #   points, or (worse) their successors)
+        # - We want to avoid the server-server API request URI becoming too long.
+        #
+        # *Note*: the spec wants us to keep backfilling until we reach the start
+        # of the room in case we are allowed to see some of the history. However,
+        # in practice that causes more issues than its worth, as (a) it's
+        # relatively rare for there to be any visible history and (b) even when
+        # there is it's often sufficiently long ago that clients would stop
+        # attempting to paginate before backfill reached the visible history.
 
-        # We don't want to specify too many extremities as it causes the backfill
-        # request URI to be too long.
-        extremities = dict(sorted_extremeties_tuple[:5])
+        extremities_to_request: List[str] = []
+        for bp in sorted_backfill_points:
+            if len(extremities_to_request) >= 5:
+                break
+
+            # For regular backwards extremities, we don't have the extremity events
+            # themselves, so we need to actually check the events that reference them -
+            # their "successor" events.
+            #
+            # TODO: Correctly handle the case where we are allowed to see the
+            #   successor event but not the backward extremity, e.g. in the case of
+            #   initial join of the server where we are allowed to see the join
+            #   event but not anything before it. This would require looking at the
+            #   state *before* the event, ignoring the special casing certain event
+            #   types have.
+            if bp.type == _BackfillPointType.INSERTION_PONT:
+                event_ids_to_check = [bp.event_id]
+            else:
+                event_ids_to_check = await self.store.get_successor_events(bp.event_id)
+
+            events_to_check = await self.store.get_events_as_list(
+                event_ids_to_check,
+                redact_behaviour=EventRedactBehaviour.AS_IS,
+                get_prev_content=False,
+            )
+
+            # We set `check_history_visibility_only` as we might otherwise get false
+            # positives from users having been erased.
+            filtered_extremities = await filter_events_for_server(
+                self.storage,
+                self.server_name,
+                events_to_check,
+                redact=False,
+                check_history_visibility_only=True,
+            )
+            if filtered_extremities:
+                extremities_to_request.append(bp.event_id)
+            else:
+                logger.debug(
+                    "_maybe_backfill_inner: skipping extremity %s as it would not be visible",
+                    bp,
+                )
+
+        if not extremities_to_request:
+            logger.debug(
+                "_maybe_backfill_inner: found no extremities which would be visible"
+            )
+            return False
+
+        logger.debug(
+            "_maybe_backfill_inner: extremities_to_request %s", extremities_to_request
+        )
 
         # Now we need to decide which hosts to hit first.
 
@@ -309,7 +365,7 @@ class FederationHandler:
             for dom in domains:
                 try:
                     await self._federation_event_handler.backfill(
-                        dom, room_id, limit=100, extremities=extremities
+                        dom, room_id, limit=100, extremities=extremities_to_request
                     )
                     # If this succeeded then we probably already have the
                     # appropriate stuff.
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 78e299d3a5..29de7e5bed 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -54,7 +54,7 @@ class RoomBatchHandler:
         # it has a larger `depth` but before the successor event because the `stream_ordering`
         # is negative before the successor event.
         successor_event_ids = await self.store.get_successor_events(
-            [most_recent_prev_event_id]
+            most_recent_prev_event_id
         )
 
         # If we can't find any successor events, then it's a forward extremity of
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 634e19e035..4710224708 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -695,7 +695,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # Return all events where not all sets can reach them.
         return {eid for eid, n in event_to_missing_sets.items() if n}
 
-    async def get_oldest_event_ids_with_depth_in_room(self, room_id) -> Dict[str, int]:
+    async def get_oldest_event_ids_with_depth_in_room(
+        self, room_id
+    ) -> List[Tuple[str, int]]:
         """Gets the oldest events(backwards extremities) in the room along with the
         aproximate depth.
 
@@ -708,7 +710,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             room_id: Room where we want to find the oldest events
 
         Returns:
-            Map from event_id to depth
+            List of (event_id, depth) tuples
         """
 
         def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id):
@@ -741,7 +743,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
             txn.execute(sql, (room_id, False))
 
-            return dict(txn)
+            return txn.fetchall()
 
         return await self.db_pool.runInteraction(
             "get_oldest_event_ids_with_depth_in_room",
@@ -751,7 +753,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
     async def get_insertion_event_backward_extremities_in_room(
         self, room_id
-    ) -> Dict[str, int]:
+    ) -> List[Tuple[str, int]]:
         """Get the insertion events we know about that we haven't backfilled yet.
 
         We use this function so that we can compare and see if someones current
@@ -763,7 +765,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             room_id: Room where we want to find the oldest events
 
         Returns:
-            Map from event_id to depth
+            List of (event_id, depth) tuples
         """
 
         def get_insertion_event_backward_extremities_in_room_txn(txn, room_id):
@@ -778,8 +780,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             """
 
             txn.execute(sql, (room_id,))
-
-            return dict(txn)
+            return txn.fetchall()
 
         return await self.db_pool.runInteraction(
             "get_insertion_event_backward_extremities_in_room",
@@ -1295,22 +1296,19 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         event_results.reverse()
         return event_results
 
-    async def get_successor_events(self, event_ids: Iterable[str]) -> List[str]:
-        """Fetch all events that have the given events as a prev event
+    async def get_successor_events(self, event_id: str) -> List[str]:
+        """Fetch all events that have the given event as a prev event
 
         Args:
-            event_ids: The events to use as the previous events.
+            event_id: The event to search for as a prev_event.
         """
-        rows = await self.db_pool.simple_select_many_batch(
+        return await self.db_pool.simple_select_onecol(
             table="event_edges",
-            column="prev_event_id",
-            iterable=event_ids,
-            retcols=("event_id",),
+            keyvalues={"prev_event_id": event_id},
+            retcol="event_id",
             desc="get_successor_events",
         )
 
-        return [row["event_id"] for row in rows]
-
     @wrap_as_background_process("delete_old_forward_extrem_cache")
     async def _delete_old_forward_extrem_cache(self) -> None:
         def _delete_old_forward_extrem_cache_txn(txn):
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 250f073597..de6d2ffc52 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -419,6 +419,13 @@ async def _event_to_memberships(
         return {}
 
     # for each event, get the event_ids of the membership state at those events.
+    #
+    # TODO: this means that we request the entire membership list. If there  are only
+    #   one or two users on this server, and the room is huge, this is very wasteful
+    #   (it means more db work, and churns the *stateGroupMembersCache*).
+    #   It might be that we could extend StateFilter to specify "give me keys matching
+    #   *:<server_name>", to avoid this.
+
     event_to_state_ids = await storage.state.get_state_ids_for_events(
         frozenset(e.event_id for e in events),
         state_filter=StateFilter.from_types(types=((EventTypes.Member, None),)),