diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index f39ae2d635..1529c86cc5 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -542,13 +542,15 @@ class EventsPersistenceStorageController:
return await res.get_state(self._state_controller, StateFilter.all())
async def _persist_event_batch(
- self, _room_id: str, task: _PersistEventsTask
+ self, room_id: str, task: _PersistEventsTask
) -> Dict[str, str]:
"""Callback for the _event_persist_queue
Calculates the change to current state and forward extremities, and
persists the given events and with those updates.
+ Assumes that we are only persisting events for one room at a time.
+
Returns:
A dictionary of event ID to event ID we didn't persist as we already
had another event persisted with the same TXN ID.
@@ -594,140 +596,23 @@ class EventsPersistenceStorageController:
# We can't easily parallelize these since different chunks
# might contain the same event. :(
- # NB: Assumes that we are only persisting events for one room
- # at a time.
-
- # map room_id->set[event_ids] giving the new forward
- # extremities in each room
- new_forward_extremities: Dict[str, Set[str]] = {}
-
- # map room_id->(to_delete, to_insert) where to_delete is a list
- # of type/state keys to remove from current state, and to_insert
- # is a map (type,key)->event_id giving the state delta in each
- # room
- state_delta_for_room: Dict[str, DeltaState] = {}
+ new_forward_extremities = None
+ state_delta_for_room = None
if not backfilled:
with Measure(self._clock, "_calculate_state_and_extrem"):
- # Work out the new "current state" for each room.
+ # Work out the new "current state" for the room.
# We do this by working out what the new extremities are and then
# calculating the state from that.
- events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
- for event, context in chunk:
- events_by_room.setdefault(event.room_id, []).append(
- (event, context)
- )
-
- for room_id, ev_ctx_rm in events_by_room.items():
- latest_event_ids = (
- await self.main_store.get_latest_event_ids_in_room(room_id)
- )
- new_latest_event_ids = await self._calculate_new_extremities(
- room_id, ev_ctx_rm, latest_event_ids
- )
-
- if new_latest_event_ids == latest_event_ids:
- # No change in extremities, so no change in state
- continue
-
- # there should always be at least one forward extremity.
- # (except during the initial persistence of the send_join
- # results, in which case there will be no existing
- # extremities, so we'll `continue` above and skip this bit.)
- assert new_latest_event_ids, "No forward extremities left!"
-
- new_forward_extremities[room_id] = new_latest_event_ids
-
- len_1 = (
- len(latest_event_ids) == 1
- and len(new_latest_event_ids) == 1
- )
- if len_1:
- all_single_prev_not_state = all(
- len(event.prev_event_ids()) == 1
- and not event.is_state()
- for event, ctx in ev_ctx_rm
- )
- # Don't bother calculating state if they're just
- # a long chain of single ancestor non-state events.
- if all_single_prev_not_state:
- continue
-
- state_delta_counter.inc()
- if len(new_latest_event_ids) == 1:
- state_delta_single_event_counter.inc()
-
- # This is a fairly handwavey check to see if we could
- # have guessed what the delta would have been when
- # processing one of these events.
- # What we're interested in is if the latest extremities
- # were the same when we created the event as they are
- # now. When this server creates a new event (as opposed
- # to receiving it over federation) it will use the
- # forward extremities as the prev_events, so we can
- # guess this by looking at the prev_events and checking
- # if they match the current forward extremities.
- for ev, _ in ev_ctx_rm:
- prev_event_ids = set(ev.prev_event_ids())
- if latest_event_ids == prev_event_ids:
- state_delta_reuse_delta_counter.inc()
- break
-
- logger.debug("Calculating state delta for room %s", room_id)
- with Measure(
- self._clock, "persist_events.get_new_state_after_events"
- ):
- res = await self._get_new_state_after_events(
- room_id,
- ev_ctx_rm,
- latest_event_ids,
- new_latest_event_ids,
- )
- current_state, delta_ids, new_latest_event_ids = res
-
- # there should always be at least one forward extremity.
- # (except during the initial persistence of the send_join
- # results, in which case there will be no existing
- # extremities, so we'll `continue` above and skip this bit.)
- assert new_latest_event_ids, "No forward extremities left!"
-
- new_forward_extremities[room_id] = new_latest_event_ids
-
- # If either are not None then there has been a change,
- # and we need to work out the delta (or use that
- # given)
- delta = None
- if delta_ids is not None:
- # If there is a delta we know that we've
- # only added or replaced state, never
- # removed keys entirely.
- delta = DeltaState([], delta_ids)
- elif current_state is not None:
- with Measure(
- self._clock, "persist_events.calculate_state_delta"
- ):
- delta = await self._calculate_state_delta(
- room_id, current_state
- )
-
- if delta:
- # If we have a change of state then lets check
- # whether we're actually still a member of the room,
- # or if our last user left. If we're no longer in
- # the room then we delete the current state and
- # extremities.
- is_still_joined = await self._is_server_still_joined(
- room_id,
- ev_ctx_rm,
- delta,
- )
- if not is_still_joined:
- logger.info("Server no longer in room %s", room_id)
- delta.no_longer_in_room = True
-
- state_delta_for_room[room_id] = delta
+ (
+ new_forward_extremities,
+ state_delta_for_room,
+ ) = await self._calculate_new_forward_extremities_and_state_delta(
+ room_id, chunk
+ )
await self.persist_events_store._persist_events_and_state_updates(
+ room_id,
chunk,
state_delta_for_room=state_delta_for_room,
new_forward_extremities=new_forward_extremities,
@@ -737,6 +622,117 @@ class EventsPersistenceStorageController:
return replaced_events
+ async def _calculate_new_forward_extremities_and_state_delta(
+ self, room_id: str, ev_ctx_rm: List[Tuple[EventBase, EventContext]]
+ ) -> Tuple[Optional[Set[str]], Optional[DeltaState]]:
+ """Calculates the new forward extremities and state delta for a room
+ given events to persist.
+
+ Assumes that we are only persisting events for one room at a time.
+
+ Returns:
+ A tuple of:
+ A set of str giving the new forward extremities the room
+
+ The state delta for the room.
+ """
+
+ latest_event_ids = await self.main_store.get_latest_event_ids_in_room(room_id)
+ new_latest_event_ids = await self._calculate_new_extremities(
+ room_id, ev_ctx_rm, latest_event_ids
+ )
+
+ if new_latest_event_ids == latest_event_ids:
+ # No change in extremities, so no change in state
+ return (None, None)
+
+ # there should always be at least one forward extremity.
+ # (except during the initial persistence of the send_join
+ # results, in which case there will be no existing
+ # extremities, so we'll `continue` above and skip this bit.)
+ assert new_latest_event_ids, "No forward extremities left!"
+
+ new_forward_extremities = new_latest_event_ids
+
+ len_1 = len(latest_event_ids) == 1 and len(new_latest_event_ids) == 1
+ if len_1:
+ all_single_prev_not_state = all(
+ len(event.prev_event_ids()) == 1 and not event.is_state()
+ for event, ctx in ev_ctx_rm
+ )
+ # Don't bother calculating state if they're just
+ # a long chain of single ancestor non-state events.
+ if all_single_prev_not_state:
+ return (new_forward_extremities, None)
+
+ state_delta_counter.inc()
+ if len(new_latest_event_ids) == 1:
+ state_delta_single_event_counter.inc()
+
+ # This is a fairly handwavey check to see if we could
+ # have guessed what the delta would have been when
+ # processing one of these events.
+ # What we're interested in is if the latest extremities
+ # were the same when we created the event as they are
+ # now. When this server creates a new event (as opposed
+ # to receiving it over federation) it will use the
+ # forward extremities as the prev_events, so we can
+ # guess this by looking at the prev_events and checking
+ # if they match the current forward extremities.
+ for ev, _ in ev_ctx_rm:
+ prev_event_ids = set(ev.prev_event_ids())
+ if latest_event_ids == prev_event_ids:
+ state_delta_reuse_delta_counter.inc()
+ break
+
+ logger.debug("Calculating state delta for room %s", room_id)
+ with Measure(self._clock, "persist_events.get_new_state_after_events"):
+ res = await self._get_new_state_after_events(
+ room_id,
+ ev_ctx_rm,
+ latest_event_ids,
+ new_latest_event_ids,
+ )
+ current_state, delta_ids, new_latest_event_ids = res
+
+ # there should always be at least one forward extremity.
+ # (except during the initial persistence of the send_join
+ # results, in which case there will be no existing
+ # extremities, so we'll `continue` above and skip this bit.)
+ assert new_latest_event_ids, "No forward extremities left!"
+
+ new_forward_extremities = new_latest_event_ids
+
+ # If either are not None then there has been a change,
+ # and we need to work out the delta (or use that
+ # given)
+ delta = None
+ if delta_ids is not None:
+ # If there is a delta we know that we've
+ # only added or replaced state, never
+ # removed keys entirely.
+ delta = DeltaState([], delta_ids)
+ elif current_state is not None:
+ with Measure(self._clock, "persist_events.calculate_state_delta"):
+ delta = await self._calculate_state_delta(room_id, current_state)
+
+ if delta:
+ # If we have a change of state then lets check
+ # whether we're actually still a member of the room,
+ # or if our last user left. If we're no longer in
+ # the room then we delete the current state and
+ # extremities.
+ is_still_joined = await self._is_server_still_joined(
+ room_id,
+ ev_ctx_rm,
+ delta,
+ )
+ if not is_still_joined:
+ logger.info("Server no longer in room %s", room_id)
+ delta.no_longer_in_room = True
+
+ return (new_forward_extremities, delta)
+
async def _calculate_new_extremities(
self,
room_id: str,
|