diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 547e43ab98..bddf5ef192 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -16,11 +16,11 @@ import logging
from queue import Empty, PriorityQueue
from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
-from prometheus_client import Gauge
+from prometheus_client import Counter, Gauge
from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import StoreError
-from synapse.api.room_versions import RoomVersion
+from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.events import EventBase, make_event_from_dict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -44,6 +44,12 @@ number_pdus_in_federation_queue = Gauge(
"The total number of events in the inbound federation staging",
)
+pdus_pruned_from_federation_queue = Counter(
+ "synapse_federation_server_number_inbound_pdu_pruned",
+ "The number of events in the inbound federation staging that have been "
+ "pruned due to the queue getting too long",
+)
+
logger = logging.getLogger(__name__)
@@ -665,27 +671,97 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
- async def get_oldest_events_with_depth_in_room(self, room_id):
+    async def get_oldest_event_ids_with_depth_in_room(
+        self, room_id: str
+    ) -> Dict[str, int]:
+        """Gets the oldest events (backwards extremities) in the room, along
+        with their approximate depth.
+
+        We use this function so that we can compare and see if someone's
+        current depth at their current scrollback is within pagination range
+        of the event extremities. If the current depth is close to the depth
+        of a given oldest event, we can trigger a backfill.
+
+ Args:
+ room_id: Room where we want to find the oldest events
+
+ Returns:
+ Map from event_id to depth
+ """
+
+ def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id):
+            # Assemble a dictionary with event_id -> depth for the oldest events
+            # we know of in the room. Backwards extremities are the oldest
+            # events we know of in the room, but we only know of them because
+            # some other event referenced them via prev_events; they aren't
+            # persisted in our database yet (meaning we don't know their depth
+            # specifically). So we need to look for the approximate depth of
+            # the events connected to the current backwards extremities.
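+            # For example (hypothetical event IDs): if persisted event $A
+            # lists unpersisted event $B in its prev_events, then $B is a
+            # backwards extremity, and $A's depth is our best approximation
+            # of $B's depth.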
+ sql = """
+ SELECT b.event_id, MAX(e.depth) FROM events as e
+ /**
+ * Get the edge connections from the event_edges table
+             * so we can see whether this event's prev_events point
+ * to a backward extremity in the next join.
+ */
+ INNER JOIN event_edges as g
+ ON g.event_id = e.event_id
+ /**
+ * We find the "oldest" events in the room by looking for
+             * events connected to backwards extremities (oldest events
+ * in the room that we know of so far).
+ */
+ INNER JOIN event_backward_extremities as b
+ ON g.prev_event_id = b.event_id
+ WHERE b.room_id = ? AND g.is_state is ?
+ GROUP BY b.event_id
+ """
+
+ txn.execute(sql, (room_id, False))
+
+ return dict(txn)
+
return await self.db_pool.runInteraction(
- "get_oldest_events_with_depth_in_room",
- self.get_oldest_events_with_depth_in_room_txn,
+ "get_oldest_event_ids_with_depth_in_room",
+ get_oldest_event_ids_with_depth_in_room_txn,
room_id,
)
- def get_oldest_events_with_depth_in_room_txn(self, txn, room_id):
- sql = (
- "SELECT b.event_id, MAX(e.depth) FROM events as e"
- " INNER JOIN event_edges as g"
- " ON g.event_id = e.event_id"
- " INNER JOIN event_backward_extremities as b"
- " ON g.prev_event_id = b.event_id"
- " WHERE b.room_id = ? AND g.is_state is ?"
- " GROUP BY b.event_id"
- )
+ async def get_insertion_event_backwards_extremities_in_room(
+        self, room_id: str
+ ) -> Dict[str, int]:
+ """Get the insertion events we know about that we haven't backfilled yet.
- txn.execute(sql, (room_id, False))
+        We use this function so that we can compare and see if someone's
+        current depth at their current scrollback is within pagination range
+        of the insertion event. If the current depth is close to the depth of
+        a given insertion event, we can trigger a backfill.
- return dict(txn)
+        Args:
+            room_id: Room where we want to find the insertion events
+
+        Returns:
+            Map from event_id to depth
+        """
+
+ def get_insertion_event_backwards_extremities_in_room_txn(txn, room_id):
+ sql = """
+ SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
+ /* We only want insertion events that are also marked as backwards extremities */
+ INNER JOIN insertion_event_extremities as b USING (event_id)
+ /* Get the depth of the insertion event from the events table */
+ INNER JOIN events AS e USING (event_id)
+ WHERE b.room_id = ?
+ GROUP BY b.event_id
+ """
+
+ txn.execute(sql, (room_id,))
+
+ return dict(txn)
+
+ return await self.db_pool.runInteraction(
+ "get_insertion_event_backwards_extremities_in_room",
+ get_insertion_event_backwards_extremities_in_room_txn,
+ room_id,
+ )
async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs
@@ -1035,7 +1111,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
if row[1] not in event_results:
queue.put((-row[0], row[1]))
- # Navigate up the DAG by prev_event
txn.execute(query, (event_id, False, limit - len(event_results)))
prev_event_id_results = txn.fetchall()
logger.debug(
@@ -1130,6 +1205,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
_delete_old_forward_extrem_cache_txn,
)
+ async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None:
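+        """Marks the given insertion event as a backwards extremity of the
+        room, so that it is returned by
+        `get_insertion_event_backwards_extremities_in_room` as a potential
+        backfill point.
+        """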
+ await self.db_pool.simple_upsert(
+ table="insertion_event_extremities",
+ keyvalues={"event_id": event_id},
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ },
+ insertion_values={},
+ desc="insert_insertion_extremity",
+ lock=False,
+ )
+
async def insert_received_event_to_staging(
self, origin: str, event: EventBase
) -> None:
@@ -1277,6 +1365,100 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return origin, event
+ async def prune_staged_events_in_room(
+ self,
+ room_id: str,
+ room_version: RoomVersion,
+ ) -> bool:
+ """Checks if there are lots of staged events for the room, and if so
+ prune them down.
+
+ Returns:
+ Whether any events were pruned
+ """
+
+ # First check the size of the queue.
+ count = await self.db_pool.simple_select_one_onecol(
+ table="federation_inbound_events_staging",
+ keyvalues={"room_id": room_id},
+ retcol="COALESCE(COUNT(*), 0)",
+ desc="prune_staged_events_in_room_count",
+ )
+
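+        # Only prune once the queue has grown large; below this threshold the
+        # queue is not considered "too long".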
+ if count < 100:
+ return False
+
+        # If the queue is too large, then we want to clear the entire queue,
+        # keeping only the forward extremities (i.e. the events not referenced
+        # by other events in the queue). We do this so that we can still fetch
+        # the dropped events later by backpaginating.
+ rows = await self.db_pool.simple_select_list(
+ table="federation_inbound_events_staging",
+ keyvalues={"room_id": room_id},
+ retcols=("event_id", "event_json"),
+ desc="prune_staged_events_in_room_fetch",
+ )
+
+        # Find the set of events referenced by those in the queue, and
+        # collect all the event IDs in the queue.
+ referenced_events: Set[str] = set()
+ seen_events: Set[str] = set()
+ for row in rows:
+ event_id = row["event_id"]
+ seen_events.add(event_id)
+ event_d = db_to_json(row["event_json"])
+
+            # We don't bother parsing the dicts into full-blown event objects,
+            # as that is needlessly expensive.
+
+ # We haven't checked that the `prev_events` have the right format
+ # yet, so we check as we go.
+ prev_events = event_d.get("prev_events", [])
+ if not isinstance(prev_events, list):
+ logger.info("Invalid prev_events for %s", event_id)
+ continue
+
+ if room_version.event_format == EventFormatVersions.V1:
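+                # In the v1 event format, each prev_events entry is a
+                # two-element list of [event_id, hashes].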
+ for prev_event_tuple in prev_events:
+                    if (
+                        not isinstance(prev_event_tuple, list)
+                        or len(prev_event_tuple) != 2
+                    ):
+ logger.info("Invalid prev_events for %s", event_id)
+ break
+
+ prev_event_id = prev_event_tuple[0]
+ if not isinstance(prev_event_id, str):
+ logger.info("Invalid prev_events for %s", event_id)
+ break
+
+ referenced_events.add(prev_event_id)
+ else:
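+                # In newer event formats, prev_events is just a list of
+                # event IDs.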
+ for prev_event_id in prev_events:
+ if not isinstance(prev_event_id, str):
+ logger.info("Invalid prev_events for %s", event_id)
+ break
+
+ referenced_events.add(prev_event_id)
+
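+        # Any event in the queue that is referenced by another queued event
+        # is, by definition, not a forward extremity, so it is safe to drop:
+        # we can backfill it again later from the extremities we keep.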
+ to_delete = referenced_events & seen_events
+ if not to_delete:
+ return False
+
+ pdus_pruned_from_federation_queue.inc(len(to_delete))
+ logger.info(
+ "Pruning %d events in room %s from federation queue",
+ len(to_delete),
+ room_id,
+ )
+
+ await self.db_pool.simple_delete_many(
+ table="federation_inbound_events_staging",
+ keyvalues={"room_id": room_id},
+ iterable=to_delete,
+ column="event_id",
+ desc="prune_staged_events_in_room_delete",
+ )
+
+ return True
+
async def get_all_rooms_with_staged_incoming_events(self) -> List[str]:
"""Get the room IDs of all events currently staged."""
return await self.db_pool.simple_select_onecol(
|