diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 70e636b0ba..61fc49c69c 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -31,7 +31,14 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
-from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.types import (
+ Collection,
+ PersistedEventPosition,
+ RoomStreamToken,
+ StateMap,
+ get_domain_from_id,
+)
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -68,6 +75,21 @@ stale_forward_extremities_counter = Histogram(
buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
)
+state_resolutions_during_persistence = Counter(
+ "synapse_storage_events_state_resolutions_during_persistence",
+ "Number of times we had to do state res to calculate new current state",
+)
+
+potential_times_prune_extremities = Counter(
+ "synapse_storage_events_potential_times_prune_extremities",
+ "Number of times we might be able to prune extremities",
+)
+
+times_pruned_extremities = Counter(
+ "synapse_storage_events_times_pruned_extremities",
+ "Number of times we were actually able to prune extremities",
+)
+
class _EventPeristenceQueue:
"""Queues up events so that they can be persisted in bulk with only one
@@ -454,7 +476,15 @@ class EventsPersistenceStorage:
latest_event_ids,
new_latest_event_ids,
)
- current_state, delta_ids = res
+ current_state, delta_ids, new_latest_event_ids = res
+
+ # there should always be at least one forward extremity.
+ # (except during the initial persistence of the send_join
+ # results, in which case there will be no existing
+ # extremities, so we'll `continue` above and skip this bit.)
+ assert new_latest_event_ids, "No forward extremities left!"
+
+ new_forward_extremeties[room_id] = new_latest_event_ids
# If either are not None then there has been a change,
# and we need to work out the delta (or use that
@@ -573,29 +603,35 @@ class EventsPersistenceStorage:
self,
room_id: str,
events_context: List[Tuple[EventBase, EventContext]],
- old_latest_event_ids: Iterable[str],
- new_latest_event_ids: Iterable[str],
- ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
+ old_latest_event_ids: Set[str],
+ new_latest_event_ids: Set[str],
+ ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]:
"""Calculate the current state dict after adding some new events to
a room
Args:
- room_id (str):
+ room_id:
room to which the events are being added. Used for logging etc
- events_context (list[(EventBase, EventContext)]):
+ events_context:
events and contexts which are being added to the room
- old_latest_event_ids (iterable[str]):
+ old_latest_event_ids:
the old forward extremities for the room.
- new_latest_event_ids (iterable[str]):
+ new_latest_event_ids:
the new forward extremities for the room.
Returns:
- Returns a tuple of two state maps, the first being the full new current
- state and the second being the delta to the existing current state.
- If both are None then there has been no change.
+ Returns a tuple of two state maps and a set of new forward
+ extremities.
+
+ The first state map is the full new current state and the second
+ is the delta to the existing current state. If both are None then
+ there has been no change.
+
+ The function may prune some old entries from the set of new
+ forward extremities if it's safe to do so.
If there has been a change then we only return the delta if it's
already been calculated. Conversely, if we do know the delta then
@@ -672,7 +708,7 @@ class EventsPersistenceStorage:
# If the old and new groups are the same then we don't need to do
# anything.
if old_state_groups == new_state_groups:
- return None, None
+ return None, None, new_latest_event_ids
if len(new_state_groups) == 1 and len(old_state_groups) == 1:
# If we're going from one state group to another, let's check if
@@ -689,7 +725,7 @@ class EventsPersistenceStorage:
# the current state in memory then let's also return that,
# but it doesn't matter if we don't.
new_state = state_groups_map.get(new_state_group)
- return new_state, delta_ids
+ return new_state, delta_ids, new_latest_event_ids
# Now that we have calculated new_state_groups we need to get
# their state IDs so we can resolve to a single state set.
@@ -701,7 +737,7 @@ class EventsPersistenceStorage:
if len(new_state_groups) == 1:
# If there is only one state group, then we know what the current
# state is.
- return state_groups_map[new_state_groups.pop()], None
+ return state_groups_map[new_state_groups.pop()], None, new_latest_event_ids
# Ok, we need to defer to the state handler to resolve our state sets.
@@ -734,7 +770,139 @@ class EventsPersistenceStorage:
state_res_store=StateResolutionStore(self.main_store),
)
- return res.state, None
+ state_resolutions_during_persistence.inc()
+
+ # If the returned state matches the state group of one of the new
+ # forward extremities then we check if we are able to prune some of
+ # those extremities.
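+ # (If the resolved state group is the state at one of the remaining
+ # extremities, then re-resolving state across only those extremities
+ # yields that same state, so pruning cannot change the current state.)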
+ if res.state_group and res.state_group in new_state_groups:
+ new_latest_event_ids = await self._prune_extremities(
+ room_id,
+ new_latest_event_ids,
+ res.state_group,
+ event_id_to_state_group,
+ events_context,
+ )
+
+ return res.state, None, new_latest_event_ids
+
+ async def _prune_extremities(
+ self,
+ room_id: str,
+ new_latest_event_ids: Set[str],
+ resolved_state_group: int,
+ event_id_to_state_group: Dict[str, int],
+ events_context: List[Tuple[EventBase, EventContext]],
+ ) -> Set[str]:
+ """See if we can prune any of the extremities after calculating the
+ resolved state.
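+
+ Returns:
+ The new set of forward extremities: either new_latest_event_ids
+ unchanged, or a pruned subset of it.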
+ """
+ potential_times_prune_extremities.inc()
+
+ # We keep all the extremities that have the same state group, and
+ # see if we can drop the others.
+ new_new_extrems = {
+ e
+ for e in new_latest_event_ids
+ if event_id_to_state_group[e] == resolved_state_group
+ }
+
+ dropped_extrems = set(new_latest_event_ids) - new_new_extrems
+
+ logger.debug("Might drop extremities: %s", dropped_extrems)
+
+ # We only drop events from the extremities list if:
+ # 1. we're not currently persisting them;
+ # 2. they're not our own events (or are dummy events); and
+ # 3. they're either:
+ # 1. over N hours old and more than N events ago (we use depth to
+ # calculate); or
+ # 2. we are persisting an event from the same domain and more than
+ # M events ago.
+ #
+ # The idea is that we don't want to drop events that are "legitimate"
+ # extremities (that we would want to include as prev events), only
+ # "stuck" extremities that are e.g. due to a gap in the graph.
+ #
+ # Note that we either drop all of them or none of them. If we only drop
+ # some of the events we don't know if state res would come to the same
+ # conclusion.
+
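+ # Check (1): bail out if any candidate extremity is among the events
+ # currently being persisted.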
+ for ev, _ in events_context:
+ if ev.event_id in dropped_extrems:
+ logger.debug(
+ "Not dropping extremities: %s is being persisted", ev.event_id
+ )
+ return new_latest_event_ids
+
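+ # Fetch the candidate extremities as stored (including any rejected
+ # events and without applying redactions): we only need their sender,
+ # type, depth and timestamp here.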
+ dropped_events = await self.main_store.get_events(
+ dropped_extrems,
+ allow_rejected=True,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ )
+
+ new_senders = {get_domain_from_id(e.sender) for e, _ in events_context}
+
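+ # Concrete values for the "N hours / N events" checks described above:
+ # one day old, and a depth more than 100 (or 20 for same-domain
+ # senders) below the newest event being persisted.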
+ one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
+ current_depth = max(e.depth for e, _ in events_context)
+ for event in dropped_events.values():
+ # If the event is a local dummy event then we should check it
+ # doesn't reference any local events, as we want to reference those
+ # if we send any new events.
+ #
+ # Note we do this recursively to handle the case where a dummy event
+ # references a dummy event that only references remote events.
+ #
+ # Ideally we'd figure out a way of still being able to drop old
+ # dummy events that reference local events, but this is good enough
+ # as a first cut.
+ events_to_check = [event]
+ while events_to_check:
+ new_events = set()
+ for event_to_check in events_to_check:
+ if self.is_mine_id(event_to_check.sender):
+ if event_to_check.type != EventTypes.Dummy:
+ logger.debug("Not dropping own event")
+ return new_latest_event_ids
+ new_events.update(event_to_check.prev_event_ids())
+
+ prev_events = await self.main_store.get_events(
+ new_events,
+ allow_rejected=True,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ )
+ events_to_check = prev_events.values()
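+ # Once a batch contains no local senders, new_events stays empty,
+ # get_events returns nothing and the loop terminates.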
+
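+ # A `continue` below marks this extremity as safe to drop; if neither
+ # check passes we fall through to the `return` and abort the prune.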
+ if (
+ event.origin_server_ts < one_day_ago
+ and event.depth < current_depth - 100
+ ):
+ continue
+
+ # We can be less conservative about dropping extremities from the
+ # same domain, though we do want to wait a little bit (otherwise
+ # we'll immediately remove all extremities from a given server).
+ if (
+ get_domain_from_id(event.sender) in new_senders
+ and event.depth < current_depth - 20
+ ):
+ continue
+
+ logger.debug(
+ "Not dropping %s: too new and sender domain not in new_senders: %s",
+ event.event_id,
+ new_senders,
+ )
+
+ return new_latest_event_ids
+
+ times_pruned_extremities.inc()
+
+ logger.info(
+ "Pruning forward extremities in room %s: from %s -> %s",
+ room_id,
+ new_latest_event_ids,
+ new_new_extrems,
+ )
+ return new_new_extrems
async def _calculate_state_delta(
self, room_id: str, current_state: StateMap[str]