diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py
new file mode 100644
index 0000000000..992261d07b
--- /dev/null
+++ b/synapse/storage/controllers/__init__.py
@@ -0,0 +1,46 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from synapse.storage.controllers.persist_events import (
+ EventsPersistenceStorageController,
+)
+from synapse.storage.controllers.purge_events import PurgeEventsStorageController
+from synapse.storage.controllers.state import StateGroupStorageController
+from synapse.storage.databases import Databases
+from synapse.storage.databases.main import DataStore
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+__all__ = ["Databases", "DataStore"]
+
+
+class StorageControllers:
+ """The high level interfaces for talking to various storage controller layers."""
+
+ def __init__(self, hs: "HomeServer", stores: Databases):
+ # We include the main data store here mainly so that we don't have to
+ # rewrite all the existing code to split it into high vs low level
+ # interfaces.
+ self.main = stores.main
+
+ self.purge_events = PurgeEventsStorageController(hs, stores)
+ self.state = StateGroupStorageController(hs, stores)
+
+ self.persistence = None
+ if stores.persist_events:
+ self.persistence = EventsPersistenceStorageController(hs, stores)
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
new file mode 100644
index 0000000000..ef8c135b12
--- /dev/null
+++ b/synapse/storage/controllers/persist_events.py
@@ -0,0 +1,1124 @@
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+from collections import deque
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Collection,
+ Deque,
+ Dict,
+ Generator,
+ Generic,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+)
+
+import attr
+from prometheus_client import Counter, Histogram
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.logging import opentracing
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases import Databases
+from synapse.storage.databases.main.events import DeltaState
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.types import (
+ PersistedEventPosition,
+ RoomStreamToken,
+ StateMap,
+ get_domain_from_id,
+)
+from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results
+from synapse.util.metrics import Measure
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+# The number of times we are recalculating the current state
+state_delta_counter = Counter("synapse_storage_events_state_delta", "")
+
+# The number of times we are recalculating state when there is only a
+# single forward extremity
+state_delta_single_event_counter = Counter(
+ "synapse_storage_events_state_delta_single_event", ""
+)
+
+# The number of times we are reculating state when we could have resonably
+# calculated the delta when we calculated the state for an event we were
+# persisting.
+state_delta_reuse_delta_counter = Counter(
+ "synapse_storage_events_state_delta_reuse_delta", ""
+)
+
+# The number of forward extremities for each new event.
+forward_extremities_counter = Histogram(
+ "synapse_storage_events_forward_extremities_persisted",
+ "Number of forward extremities for each new event",
+ buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
+# The number of stale forward extremities for each new event. Stale extremities
+# are those that were in the previous set of extremities as well as the new.
+stale_forward_extremities_counter = Histogram(
+ "synapse_storage_events_stale_forward_extremities_persisted",
+ "Number of unchanged forward extremities for each new event",
+ buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
+state_resolutions_during_persistence = Counter(
+ "synapse_storage_events_state_resolutions_during_persistence",
+ "Number of times we had to do state res to calculate new current state",
+)
+
+potential_times_prune_extremities = Counter(
+ "synapse_storage_events_potential_times_prune_extremities",
+ "Number of times we might be able to prune extremities",
+)
+
+times_pruned_extremities = Counter(
+ "synapse_storage_events_times_pruned_extremities",
+ "Number of times we were actually be able to prune extremities",
+)
+
+
+@attr.s(auto_attribs=True, slots=True)
+class _EventPersistQueueItem:
+ events_and_contexts: List[Tuple[EventBase, EventContext]]
+ backfilled: bool
+ deferred: ObservableDeferred
+
+ parent_opentracing_span_contexts: List = attr.ib(factory=list)
+ """A list of opentracing spans waiting for this batch"""
+
+ opentracing_span_context: Any = None
+ """The opentracing span under which the persistence actually happened"""
+
+
+_PersistResult = TypeVar("_PersistResult")
+
+
+class _EventPeristenceQueue(Generic[_PersistResult]):
+ """Queues up events so that they can be persisted in bulk with only one
+ concurrent transaction per room.
+ """
+
+ def __init__(
+ self,
+ per_item_callback: Callable[
+ [List[Tuple[EventBase, EventContext]], bool],
+ Awaitable[_PersistResult],
+ ],
+ ):
+ """Create a new event persistence queue
+
+ The per_item_callback will be called for each item added via add_to_queue,
+ and its result will be returned via the Deferreds returned from add_to_queue.
+ """
+ self._event_persist_queues: Dict[str, Deque[_EventPersistQueueItem]] = {}
+ self._currently_persisting_rooms: Set[str] = set()
+ self._per_item_callback = per_item_callback
+
+ async def add_to_queue(
+ self,
+ room_id: str,
+ events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
+ backfilled: bool,
+ ) -> _PersistResult:
+ """Add events to the queue, with the given persist_event options.
+
+ If we are not already processing events in this room, starts off a background
+ process to to so, calling the per_item_callback for each item.
+
+ Args:
+ room_id (str):
+ events_and_contexts (list[(EventBase, EventContext)]):
+ backfilled (bool):
+
+ Returns:
+ the result returned by the `_per_item_callback` passed to
+ `__init__`.
+ """
+ queue = self._event_persist_queues.setdefault(room_id, deque())
+
+ # if the last item in the queue has the same `backfilled` setting,
+ # we can just add these new events to that item.
+ if queue and queue[-1].backfilled == backfilled:
+ end_item = queue[-1]
+ else:
+ # need to make a new queue item
+ deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
+ defer.Deferred(), consumeErrors=True
+ )
+
+ end_item = _EventPersistQueueItem(
+ events_and_contexts=[],
+ backfilled=backfilled,
+ deferred=deferred,
+ )
+ queue.append(end_item)
+
+ # add our events to the queue item
+ end_item.events_and_contexts.extend(events_and_contexts)
+
+ # also add our active opentracing span to the item so that we get a link back
+ span = opentracing.active_span()
+ if span:
+ end_item.parent_opentracing_span_contexts.append(span.context)
+
+ # start a processor for the queue, if there isn't one already
+ self._handle_queue(room_id)
+
+ # wait for the queue item to complete
+ res = await make_deferred_yieldable(end_item.deferred.observe())
+
+ # add another opentracing span which links to the persist trace.
+ with opentracing.start_active_span_follows_from(
+ "persist_event_batch_complete", (end_item.opentracing_span_context,)
+ ):
+ pass
+
+ return res
+
+ def _handle_queue(self, room_id: str) -> None:
+ """Attempts to handle the queue for a room if not already being handled.
+
+ The queue's callback will be invoked with for each item in the queue,
+ of type _EventPersistQueueItem. The per_item_callback will continuously
+ be called with new items, unless the queue becomes empty. The return
+ value of the function will be given to the deferreds waiting on the item,
+ exceptions will be passed to the deferreds as well.
+
+ This function should therefore be called whenever anything is added
+ to the queue.
+
+ If another callback is currently handling the queue then it will not be
+ invoked.
+ """
+ if room_id in self._currently_persisting_rooms:
+ return
+
+ self._currently_persisting_rooms.add(room_id)
+
+ async def handle_queue_loop() -> None:
+ try:
+ queue = self._get_drainining_queue(room_id)
+ for item in queue:
+ try:
+ with opentracing.start_active_span_follows_from(
+ "persist_event_batch",
+ item.parent_opentracing_span_contexts,
+ inherit_force_tracing=True,
+ ) as scope:
+ if scope:
+ item.opentracing_span_context = scope.span.context
+
+ ret = await self._per_item_callback(
+ item.events_and_contexts, item.backfilled
+ )
+ except Exception:
+ with PreserveLoggingContext():
+ item.deferred.errback()
+ else:
+ with PreserveLoggingContext():
+ item.deferred.callback(ret)
+ finally:
+ remaining_queue = self._event_persist_queues.pop(room_id, None)
+ if remaining_queue:
+ self._event_persist_queues[room_id] = remaining_queue
+ self._currently_persisting_rooms.discard(room_id)
+
+ # set handle_queue_loop off in the background
+ run_as_background_process("persist_events", handle_queue_loop)
+
+ def _get_drainining_queue(
+ self, room_id: str
+ ) -> Generator[_EventPersistQueueItem, None, None]:
+ queue = self._event_persist_queues.setdefault(room_id, deque())
+
+ try:
+ while True:
+ yield queue.popleft()
+ except IndexError:
+ # Queue has been drained.
+ pass
+
+
+class EventsPersistenceStorageController:
+ """High level interface for handling persisting newly received events.
+
+ Takes care of batching up events by room, and calculating the necessary
+ current state and forward extremity changes.
+ """
+
+ def __init__(self, hs: "HomeServer", stores: Databases):
+ # We ultimately want to split out the state store from the main store,
+ # so we use separate variables here even though they point to the same
+ # store for now.
+ self.main_store = stores.main
+ self.state_store = stores.state
+
+ assert stores.persist_events
+ self.persist_events_store = stores.persist_events
+
+ self._clock = hs.get_clock()
+ self._instance_name = hs.get_instance_name()
+ self.is_mine_id = hs.is_mine_id
+ self._event_persist_queue = _EventPeristenceQueue(self._persist_event_batch)
+ self._state_resolution_handler = hs.get_state_resolution_handler()
+
+ @opentracing.trace
+ async def persist_events(
+ self,
+ events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
+ backfilled: bool = False,
+ ) -> Tuple[List[EventBase], RoomStreamToken]:
+ """
+ Write events to the database
+ Args:
+ events_and_contexts: list of tuples of (event, context)
+ backfilled: Whether the results are retrieved from federation
+ via backfill or not. Used to determine if they're "new" events
+ which might update the current state etc.
+
+ Returns:
+ List of events persisted, the current position room stream position.
+ The list of events persisted may not be the same as those passed in
+ if they were deduplicated due to an event already existing that
+ matched the transaction ID; the existing event is returned in such
+ a case.
+ """
+ partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
+ for event, ctx in events_and_contexts:
+ partitioned.setdefault(event.room_id, []).append((event, ctx))
+
+ async def enqueue(
+ item: Tuple[str, List[Tuple[EventBase, EventContext]]]
+ ) -> Dict[str, str]:
+ room_id, evs_ctxs = item
+ return await self._event_persist_queue.add_to_queue(
+ room_id, evs_ctxs, backfilled=backfilled
+ )
+
+ ret_vals = await yieldable_gather_results(enqueue, partitioned.items())
+
+ # Each call to add_to_queue returns a map from event ID to existing event ID if
+ # the event was deduplicated. (The dict may also include other entries if
+ # the event was persisted in a batch with other events).
+ #
+ # Since we use `yieldable_gather_results` we need to merge the returned list
+ # of dicts into one.
+ replaced_events: Dict[str, str] = {}
+ for d in ret_vals:
+ replaced_events.update(d)
+
+ events = []
+ for event, _ in events_and_contexts:
+ existing_event_id = replaced_events.get(event.event_id)
+ if existing_event_id:
+ events.append(await self.main_store.get_event(existing_event_id))
+ else:
+ events.append(event)
+
+ return (
+ events,
+ self.main_store.get_room_max_token(),
+ )
+
+ @opentracing.trace
+ async def persist_event(
+ self, event: EventBase, context: EventContext, backfilled: bool = False
+ ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
+ """
+ Returns:
+ The event, stream ordering of `event`, and the stream ordering of the
+ latest persisted event. The returned event may not match the given
+ event if it was deduplicated due to an existing event matching the
+ transaction ID.
+ """
+ # add_to_queue returns a map from event ID to existing event ID if the
+ # event was deduplicated. (The dict may also include other entries if
+ # the event was persisted in a batch with other events.)
+ replaced_events = await self._event_persist_queue.add_to_queue(
+ event.room_id, [(event, context)], backfilled=backfilled
+ )
+ replaced_event = replaced_events.get(event.event_id)
+ if replaced_event:
+ event = await self.main_store.get_event(replaced_event)
+
+ event_stream_id = event.internal_metadata.stream_ordering
+ # stream ordering should have been assigned by now
+ assert event_stream_id
+
+ pos = PersistedEventPosition(self._instance_name, event_stream_id)
+ return event, pos, self.main_store.get_room_max_token()
+
+ async def update_current_state(self, room_id: str) -> None:
+ """Recalculate the current state for a room, and persist it"""
+ state = await self._calculate_current_state(room_id)
+ delta = await self._calculate_state_delta(room_id, state)
+
+ # TODO(faster_joins): get a real stream ordering, to make this work correctly
+ # across workers.
+ #
+ # TODO(faster_joins): this can race against event persistence, in which case we
+ # will end up with incorrect state. Perhaps we should make this a job we
+ # farm out to the event persister, somehow.
+ stream_id = self.main_store.get_room_max_stream_ordering()
+ await self.persist_events_store.update_current_state(room_id, delta, stream_id)
+
+ async def _calculate_current_state(self, room_id: str) -> StateMap[str]:
+ """Calculate the current state of a room, based on the forward extremities
+
+ Args:
+ room_id: room for which to calculate current state
+
+ Returns:
+ map from (type, state_key) to event id for the current state in the room
+ """
+ latest_event_ids = await self.main_store.get_latest_event_ids_in_room(room_id)
+ state_groups = set(
+ (
+ await self.main_store._get_state_group_for_events(latest_event_ids)
+ ).values()
+ )
+
+ state_maps_by_state_group = await self.state_store._get_state_for_groups(
+ state_groups
+ )
+
+ if len(state_groups) == 1:
+ # If there is only one state group, then we know what the current
+ # state is.
+ return state_maps_by_state_group[state_groups.pop()]
+
+ # Ok, we need to defer to the state handler to resolve our state sets.
+ logger.debug("calling resolve_state_groups from preserve_events")
+
+ # Avoid a circular import.
+ from synapse.state import StateResolutionStore
+
+ room_version = await self.main_store.get_room_version_id(room_id)
+ res = await self._state_resolution_handler.resolve_state_groups(
+ room_id,
+ room_version,
+ state_maps_by_state_group,
+ event_map=None,
+ state_res_store=StateResolutionStore(self.main_store),
+ )
+
+ return res.state
+
+ async def _persist_event_batch(
+ self,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ backfilled: bool = False,
+ ) -> Dict[str, str]:
+ """Callback for the _event_persist_queue
+
+ Calculates the change to current state and forward extremities, and
+ persists the given events and with those updates.
+
+ Returns:
+ A dictionary of event ID to event ID we didn't persist as we already
+ had another event persisted with the same TXN ID.
+ """
+ replaced_events: Dict[str, str] = {}
+ if not events_and_contexts:
+ return replaced_events
+
+ # Check if any of the events have a transaction ID that has already been
+ # persisted, and if so we don't persist it again.
+ #
+ # We should have checked this a long time before we get here, but it's
+ # possible that different send event requests race in such a way that
+ # they both pass the earlier checks. Checking here isn't racey as we can
+ # have only one `_persist_events` per room being called at a time.
+ replaced_events = await self.main_store.get_already_persisted_events(
+ (event for event, _ in events_and_contexts)
+ )
+
+ if replaced_events:
+ events_and_contexts = [
+ (e, ctx)
+ for e, ctx in events_and_contexts
+ if e.event_id not in replaced_events
+ ]
+
+ if not events_and_contexts:
+ return replaced_events
+
+ chunks = [
+ events_and_contexts[x : x + 100]
+ for x in range(0, len(events_and_contexts), 100)
+ ]
+
+ for chunk in chunks:
+ # We can't easily parallelize these since different chunks
+ # might contain the same event. :(
+
+ # NB: Assumes that we are only persisting events for one room
+ # at a time.
+
+ # map room_id->set[event_ids] giving the new forward
+ # extremities in each room
+ new_forward_extremities: Dict[str, Set[str]] = {}
+
+ # map room_id->(to_delete, to_insert) where to_delete is a list
+ # of type/state keys to remove from current state, and to_insert
+ # is a map (type,key)->event_id giving the state delta in each
+ # room
+ state_delta_for_room: Dict[str, DeltaState] = {}
+
+ # Set of remote users which were in rooms the server has left. We
+ # should check if we still share any rooms and if not we mark their
+ # device lists as stale.
+ potentially_left_users: Set[str] = set()
+
+ if not backfilled:
+ with Measure(self._clock, "_calculate_state_and_extrem"):
+ # Work out the new "current state" for each room.
+ # We do this by working out what the new extremities are and then
+ # calculating the state from that.
+ events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
+ for event, context in chunk:
+ events_by_room.setdefault(event.room_id, []).append(
+ (event, context)
+ )
+
+ for room_id, ev_ctx_rm in events_by_room.items():
+ latest_event_ids = set(
+ await self.main_store.get_latest_event_ids_in_room(room_id)
+ )
+ new_latest_event_ids = await self._calculate_new_extremities(
+ room_id, ev_ctx_rm, latest_event_ids
+ )
+
+ if new_latest_event_ids == latest_event_ids:
+ # No change in extremities, so no change in state
+ continue
+
+ # there should always be at least one forward extremity.
+ # (except during the initial persistence of the send_join
+ # results, in which case there will be no existing
+ # extremities, so we'll `continue` above and skip this bit.)
+ assert new_latest_event_ids, "No forward extremities left!"
+
+ new_forward_extremities[room_id] = new_latest_event_ids
+
+ len_1 = (
+ len(latest_event_ids) == 1
+ and len(new_latest_event_ids) == 1
+ )
+ if len_1:
+ all_single_prev_not_state = all(
+ len(event.prev_event_ids()) == 1
+ and not event.is_state()
+ for event, ctx in ev_ctx_rm
+ )
+ # Don't bother calculating state if they're just
+ # a long chain of single ancestor non-state events.
+ if all_single_prev_not_state:
+ continue
+
+ state_delta_counter.inc()
+ if len(new_latest_event_ids) == 1:
+ state_delta_single_event_counter.inc()
+
+ # This is a fairly handwavey check to see if we could
+ # have guessed what the delta would have been when
+ # processing one of these events.
+ # What we're interested in is if the latest extremities
+ # were the same when we created the event as they are
+ # now. When this server creates a new event (as opposed
+ # to receiving it over federation) it will use the
+ # forward extremities as the prev_events, so we can
+ # guess this by looking at the prev_events and checking
+ # if they match the current forward extremities.
+ for ev, _ in ev_ctx_rm:
+ prev_event_ids = set(ev.prev_event_ids())
+ if latest_event_ids == prev_event_ids:
+ state_delta_reuse_delta_counter.inc()
+ break
+
+ logger.debug("Calculating state delta for room %s", room_id)
+ with Measure(
+ self._clock, "persist_events.get_new_state_after_events"
+ ):
+ res = await self._get_new_state_after_events(
+ room_id,
+ ev_ctx_rm,
+ latest_event_ids,
+ new_latest_event_ids,
+ )
+ current_state, delta_ids, new_latest_event_ids = res
+
+ # there should always be at least one forward extremity.
+ # (except during the initial persistence of the send_join
+ # results, in which case there will be no existing
+ # extremities, so we'll `continue` above and skip this bit.)
+ assert new_latest_event_ids, "No forward extremities left!"
+
+ new_forward_extremities[room_id] = new_latest_event_ids
+
+ # If either are not None then there has been a change,
+ # and we need to work out the delta (or use that
+ # given)
+ delta = None
+ if delta_ids is not None:
+ # If there is a delta we know that we've
+ # only added or replaced state, never
+ # removed keys entirely.
+ delta = DeltaState([], delta_ids)
+ elif current_state is not None:
+ with Measure(
+ self._clock, "persist_events.calculate_state_delta"
+ ):
+ delta = await self._calculate_state_delta(
+ room_id, current_state
+ )
+
+ if delta:
+ # If we have a change of state then lets check
+ # whether we're actually still a member of the room,
+ # or if our last user left. If we're no longer in
+ # the room then we delete the current state and
+ # extremities.
+ is_still_joined = await self._is_server_still_joined(
+ room_id,
+ ev_ctx_rm,
+ delta,
+ current_state,
+ potentially_left_users,
+ )
+ if not is_still_joined:
+ logger.info("Server no longer in room %s", room_id)
+ latest_event_ids = set()
+ current_state = {}
+ delta.no_longer_in_room = True
+
+ state_delta_for_room[room_id] = delta
+
+ await self.persist_events_store._persist_events_and_state_updates(
+ chunk,
+ state_delta_for_room=state_delta_for_room,
+ new_forward_extremities=new_forward_extremities,
+ use_negative_stream_ordering=backfilled,
+ inhibit_local_membership_updates=backfilled,
+ )
+
+ await self._handle_potentially_left_users(potentially_left_users)
+
+ return replaced_events
+
+ async def _calculate_new_extremities(
+ self,
+ room_id: str,
+ event_contexts: List[Tuple[EventBase, EventContext]],
+ latest_event_ids: Collection[str],
+ ) -> Set[str]:
+ """Calculates the new forward extremities for a room given events to
+ persist.
+
+ Assumes that we are only persisting events for one room at a time.
+ """
+
+ # we're only interested in new events which aren't outliers and which aren't
+ # being rejected.
+ new_events = [
+ event
+ for event, ctx in event_contexts
+ if not event.internal_metadata.is_outlier()
+ and not ctx.rejected
+ and not event.internal_metadata.is_soft_failed()
+ ]
+
+ latest_event_ids = set(latest_event_ids)
+
+ # start with the existing forward extremities
+ result = set(latest_event_ids)
+
+ # add all the new events to the list
+ result.update(event.event_id for event in new_events)
+
+ # Now remove all events which are prev_events of any of the new events
+ result.difference_update(
+ e_id for event in new_events for e_id in event.prev_event_ids()
+ )
+
+ # Remove any events which are prev_events of any existing events.
+ existing_prevs: Collection[
+ str
+ ] = await self.persist_events_store._get_events_which_are_prevs(result)
+ result.difference_update(existing_prevs)
+
+ # Finally handle the case where the new events have soft-failed prev
+ # events. If they do we need to remove them and their prev events,
+ # otherwise we end up with dangling extremities.
+ existing_prevs = await self.persist_events_store._get_prevs_before_rejected(
+ e_id for event in new_events for e_id in event.prev_event_ids()
+ )
+ result.difference_update(existing_prevs)
+
+ # We only update metrics for events that change forward extremities
+ # (e.g. we ignore backfill/outliers/etc)
+ if result != latest_event_ids:
+ forward_extremities_counter.observe(len(result))
+ stale = latest_event_ids & result
+ stale_forward_extremities_counter.observe(len(stale))
+
+ return result
+
+ async def _get_new_state_after_events(
+ self,
+ room_id: str,
+ events_context: List[Tuple[EventBase, EventContext]],
+ old_latest_event_ids: Set[str],
+ new_latest_event_ids: Set[str],
+ ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]:
+ """Calculate the current state dict after adding some new events to
+ a room
+
+ Args:
+ room_id:
+ room to which the events are being added. Used for logging etc
+
+ events_context:
+ events and contexts which are being added to the room
+
+ old_latest_event_ids:
+ the old forward extremities for the room.
+
+ new_latest_event_ids :
+ the new forward extremities for the room.
+
+ Returns:
+ Returns a tuple of two state maps and a set of new forward
+ extremities.
+
+ The first state map is the full new current state and the second
+ is the delta to the existing current state. If both are None then
+ there has been no change. Either or neither can be None if there
+ has been a change.
+
+ The function may prune some old entries from the set of new
+ forward extremities if it's safe to do so.
+
+ If there has been a change then we only return the delta if its
+ already been calculated. Conversely if we do know the delta then
+ the new current state is only returned if we've already calculated
+ it.
+ """
+ # Map from (prev state group, new state group) -> delta state dict
+ state_group_deltas = {}
+
+ for ev, ctx in events_context:
+ if ctx.state_group is None:
+ # This should only happen for outlier events.
+ if not ev.internal_metadata.is_outlier():
+ raise Exception(
+ "Context for new event %s has no state "
+ "group" % (ev.event_id,)
+ )
+ continue
+
+ if ctx.prev_group:
+ state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
+
+ # We need to map the event_ids to their state groups. First, let's
+ # check if the event is one we're persisting, in which case we can
+ # pull the state group from its context.
+ # Otherwise we need to pull the state group from the database.
+
+ # Set of events we need to fetch groups for. (We know none of the old
+ # extremities are going to be in events_context).
+ missing_event_ids = set(old_latest_event_ids)
+
+ event_id_to_state_group = {}
+ for event_id in new_latest_event_ids:
+ # First search in the list of new events we're adding.
+ for ev, ctx in events_context:
+ if event_id == ev.event_id and ctx.state_group is not None:
+ event_id_to_state_group[event_id] = ctx.state_group
+ break
+ else:
+ # If we couldn't find it, then we'll need to pull
+ # the state from the database
+ missing_event_ids.add(event_id)
+
+ if missing_event_ids:
+ # Now pull out the state groups for any missing events from DB
+ event_to_groups = await self.main_store._get_state_group_for_events(
+ missing_event_ids
+ )
+ event_id_to_state_group.update(event_to_groups)
+
+ # State groups of old_latest_event_ids
+ old_state_groups = {
+ event_id_to_state_group[evid] for evid in old_latest_event_ids
+ }
+
+ # State groups of new_latest_event_ids
+ new_state_groups = {
+ event_id_to_state_group[evid] for evid in new_latest_event_ids
+ }
+
+ # If they old and new groups are the same then we don't need to do
+ # anything.
+ if old_state_groups == new_state_groups:
+ return None, None, new_latest_event_ids
+
+ if len(new_state_groups) == 1 and len(old_state_groups) == 1:
+ # If we're going from one state group to another, lets check if
+ # we have a delta for that transition. If we do then we can just
+ # return that.
+
+ new_state_group = next(iter(new_state_groups))
+ old_state_group = next(iter(old_state_groups))
+
+ delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
+ if delta_ids is not None:
+ # We have a delta from the existing to new current state,
+ # so lets just return that.
+ return None, delta_ids, new_latest_event_ids
+
+ # Now that we have calculated new_state_groups we need to get
+ # their state IDs so we can resolve to a single state set.
+ state_groups_map = await self.state_store._get_state_for_groups(
+ new_state_groups
+ )
+
+ if len(new_state_groups) == 1:
+ # If there is only one state group, then we know what the current
+ # state is.
+ return state_groups_map[new_state_groups.pop()], None, new_latest_event_ids
+
+ # Ok, we need to defer to the state handler to resolve our state sets.
+
+ state_groups = {sg: state_groups_map[sg] for sg in new_state_groups}
+
+ events_map = {ev.event_id: ev for ev, _ in events_context}
+
+ # We need to get the room version, which is in the create event.
+ # Normally that'd be in the database, but its also possible that we're
+ # currently trying to persist it.
+ room_version = None
+ for ev, _ in events_context:
+ if ev.type == EventTypes.Create and ev.state_key == "":
+ room_version = ev.content.get("room_version", "1")
+ break
+
+ if not room_version:
+ room_version = await self.main_store.get_room_version_id(room_id)
+
+ logger.debug("calling resolve_state_groups from preserve_events")
+
+ # Avoid a circular import.
+ from synapse.state import StateResolutionStore
+
+ res = await self._state_resolution_handler.resolve_state_groups(
+ room_id,
+ room_version,
+ state_groups,
+ events_map,
+ state_res_store=StateResolutionStore(self.main_store),
+ )
+
+ state_resolutions_during_persistence.inc()
+
+ # If the returned state matches the state group of one of the new
+ # forward extremities then we check if we are able to prune some state
+ # extremities.
+ if res.state_group and res.state_group in new_state_groups:
+ new_latest_event_ids = await self._prune_extremities(
+ room_id,
+ new_latest_event_ids,
+ res.state_group,
+ event_id_to_state_group,
+ events_context,
+ )
+
+ return res.state, None, new_latest_event_ids
+
+ async def _prune_extremities(
+ self,
+ room_id: str,
+ new_latest_event_ids: Set[str],
+ resolved_state_group: int,
+ event_id_to_state_group: Dict[str, int],
+ events_context: List[Tuple[EventBase, EventContext]],
+ ) -> Set[str]:
+ """See if we can prune any of the extremities after calculating the
+ resolved state.
+ """
+ potential_times_prune_extremities.inc()
+
+ # We keep all the extremities that have the same state group, and
+ # see if we can drop the others.
+ new_new_extrems = {
+ e
+ for e in new_latest_event_ids
+ if event_id_to_state_group[e] == resolved_state_group
+ }
+
+ dropped_extrems = set(new_latest_event_ids) - new_new_extrems
+
+ logger.debug("Might drop extremities: %s", dropped_extrems)
+
+ # We only drop events from the extremities list if:
+ # 1. we're not currently persisting them;
+ # 2. they're not our own events (or are dummy events); and
+ # 3. they're either:
+ # 1. over N hours old and more than N events ago (we use depth to
+ # calculate); or
+ # 2. we are persisting an event from the same domain and more than
+ # M events ago.
+ #
+ # The idea is that we don't want to drop events that are "legitimate"
+ # extremities (that we would want to include as prev events), only
+ # "stuck" extremities that are e.g. due to a gap in the graph.
+ #
+ # Note that we either drop all of them or none of them. If we only drop
+ # some of the events we don't know if state res would come to the same
+ # conclusion.
+
+ for ev, _ in events_context:
+ if ev.event_id in dropped_extrems:
+ logger.debug(
+ "Not dropping extremities: %s is being persisted", ev.event_id
+ )
+ return new_latest_event_ids
+
+ dropped_events = await self.main_store.get_events(
+ dropped_extrems,
+ allow_rejected=True,
+ redact_behaviour=EventRedactBehaviour.as_is,
+ )
+
+ new_senders = {get_domain_from_id(e.sender) for e, _ in events_context}
+
+ one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
+ current_depth = max(e.depth for e, _ in events_context)
+ for event in dropped_events.values():
+ # If the event is a local dummy event then we should check it
+ # doesn't reference any local events, as we want to reference those
+ # if we send any new events.
+ #
+ # Note we do this recursively to handle the case where a dummy event
+ # references a dummy event that only references remote events.
+ #
+ # Ideally we'd figure out a way of still being able to drop old
+ # dummy events that reference local events, but this is good enough
+ # as a first cut.
+ events_to_check: Collection[EventBase] = [event]
+ while events_to_check:
+ new_events: Set[str] = set()
+ for event_to_check in events_to_check:
+ if self.is_mine_id(event_to_check.sender):
+ if event_to_check.type != EventTypes.Dummy:
+ logger.debug("Not dropping own event")
+ return new_latest_event_ids
+ new_events.update(event_to_check.prev_event_ids())
+
+ prev_events = await self.main_store.get_events(
+ new_events,
+ allow_rejected=True,
+ redact_behaviour=EventRedactBehaviour.as_is,
+ )
+ events_to_check = prev_events.values()
+
+ if (
+ event.origin_server_ts < one_day_ago
+ and event.depth < current_depth - 100
+ ):
+ continue
+
+ # We can be less conservative about dropping extremities from the
+ # same domain, though we do want to wait a little bit (otherwise
+ # we'll immediately remove all extremities from a given server).
+ if (
+ get_domain_from_id(event.sender) in new_senders
+ and event.depth < current_depth - 20
+ ):
+ continue
+
+ logger.debug(
+ "Not dropping as too new and not in new_senders: %s",
+ new_senders,
+ )
+
+ return new_latest_event_ids
+
+ times_pruned_extremities.inc()
+
+ logger.info(
+ "Pruning forward extremities in room %s: from %s -> %s",
+ room_id,
+ new_latest_event_ids,
+ new_new_extrems,
+ )
+ return new_new_extrems
+
+ async def _calculate_state_delta(
+ self, room_id: str, current_state: StateMap[str]
+ ) -> DeltaState:
+ """Calculate the new state deltas for a room.
+
+ Assumes that we are only persisting events for one room at a time.
+ """
+ existing_state = await self.main_store.get_current_state_ids(room_id)
+
+ to_delete = [key for key in existing_state if key not in current_state]
+
+ to_insert = {
+ key: ev_id
+ for key, ev_id in current_state.items()
+ if ev_id != existing_state.get(key)
+ }
+
+ return DeltaState(to_delete=to_delete, to_insert=to_insert)
+
+ async def _is_server_still_joined(
+ self,
+ room_id: str,
+ ev_ctx_rm: List[Tuple[EventBase, EventContext]],
+ delta: DeltaState,
+ current_state: Optional[StateMap[str]],
+ potentially_left_users: Set[str],
+ ) -> bool:
+ """Check if the server will still be joined after the given events have
+ been persised.
+
+ Args:
+ room_id
+ ev_ctx_rm
+ delta: The delta of current state between what is in the database
+ and what the new current state will be.
+ current_state: The new current state if it already been calculated,
+ otherwise None.
+ potentially_left_users: If the server has left the room, then joined
+ remote users will be added to this set to indicate that the
+ server may no longer be sharing a room with them.
+ """
+
+ if not any(
+ self.is_mine_id(state_key)
+ for typ, state_key in itertools.chain(delta.to_delete, delta.to_insert)
+ if typ == EventTypes.Member
+ ):
+ # There have been no changes to membership of our users, so nothing
+ # has changed and we assume we're still in the room.
+ return True
+
+ # Check if any of the given events are a local join that appear in the
+ # current state
+ events_to_check = [] # Event IDs that aren't an event we're persisting
+ for (typ, state_key), event_id in delta.to_insert.items():
+ if typ != EventTypes.Member or not self.is_mine_id(state_key):
+ continue
+
+ for event, _ in ev_ctx_rm:
+ if event_id == event.event_id:
+ if event.membership == Membership.JOIN:
+ return True
+
+ # The event is not in `ev_ctx_rm`, so we need to pull it out of
+ # the DB.
+ events_to_check.append(event_id)
+
+ # Check if any of the changes that we don't have events for are joins.
+ if events_to_check:
+ members = await self.main_store.get_membership_from_event_ids(
+ events_to_check
+ )
+ is_still_joined = any(
+ member and member.membership == Membership.JOIN
+ for member in members.values()
+ )
+ if is_still_joined:
+ return True
+
+ # None of the new state events are local joins, so we check the database
+ # to see if there are any other local users in the room. We ignore users
+ # whose state has changed as we've already their new state above.
+ users_to_ignore = [
+ state_key
+ for typ, state_key in itertools.chain(delta.to_insert, delta.to_delete)
+ if typ == EventTypes.Member and self.is_mine_id(state_key)
+ ]
+
+ if await self.main_store.is_local_host_in_room_ignoring_users(
+ room_id, users_to_ignore
+ ):
+ return True
+
+ # The server will leave the room, so we go and find out which remote
+ # users will still be joined when we leave.
+ if current_state is None:
+ current_state = await self.main_store.get_current_state_ids(room_id)
+ current_state = dict(current_state)
+ for key in delta.to_delete:
+ current_state.pop(key, None)
+
+ current_state.update(delta.to_insert)
+
+ remote_event_ids = [
+ event_id
+ for (
+ typ,
+ state_key,
+ ), event_id in current_state.items()
+ if typ == EventTypes.Member and not self.is_mine_id(state_key)
+ ]
+ members = await self.main_store.get_membership_from_event_ids(remote_event_ids)
+ potentially_left_users.update(
+ member.user_id
+ for member in members.values()
+ if member and member.membership == Membership.JOIN
+ )
+
+ return False
+
+ async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None:
+ """Given a set of remote users check if the server still shares a room with
+ them. If not then mark those users' device cache as stale.
+ """
+
+ if not user_ids:
+ return
+
+ joined_users = await self.main_store.get_users_server_still_shares_room_with(
+ user_ids
+ )
+ left_users = user_ids - joined_users
+
+ for user_id in left_users:
+ await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
diff --git a/synapse/storage/controllers/purge_events.py b/synapse/storage/controllers/purge_events.py
new file mode 100644
index 0000000000..9ca50d6a09
--- /dev/null
+++ b/synapse/storage/controllers/purge_events.py
@@ -0,0 +1,112 @@
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+from typing import TYPE_CHECKING, Set
+
+from synapse.storage.databases import Databases
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class PurgeEventsStorageController:
+ """High level interface for purging rooms and event history."""
+
+ def __init__(self, hs: "HomeServer", stores: Databases):
+ self.stores = stores
+
+ async def purge_room(self, room_id: str) -> None:
+ """Deletes all record of a room"""
+
+ state_groups_to_delete = await self.stores.main.purge_room(room_id)
+ await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
+
+ async def purge_history(
+ self, room_id: str, token: str, delete_local_events: bool
+ ) -> None:
+ """Deletes room history before a certain point
+
+ Args:
+ room_id: The room ID
+
+ token: A topological token to delete events before
+
+ delete_local_events:
+ if True, we will delete local events as well as remote ones
+ (instead of just marking them as outliers and deleting their
+ state groups).
+ """
+ state_groups = await self.stores.main.purge_history(
+ room_id, token, delete_local_events
+ )
+
+ logger.info("[purge] finding state groups that can be deleted")
+
+ sg_to_delete = await self._find_unreferenced_groups(state_groups)
+
+ await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
+
+ async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]:
+ """Used when purging history to figure out which state groups can be
+ deleted.
+
+ Args:
+ state_groups: Set of state groups referenced by events
+ that are going to be deleted.
+
+ Returns:
+ The set of state groups that can be deleted.
+ """
+ # Set of events that we have found to be referenced by events
+ referenced_groups = set()
+
+ # Set of state groups we've already seen
+ state_groups_seen = set(state_groups)
+
+ # Set of state groups to handle next.
+ next_to_search = set(state_groups)
+ while next_to_search:
+ # We bound size of groups we're looking up at once, to stop the
+ # SQL query getting too big
+ if len(next_to_search) < 100:
+ current_search = next_to_search
+ next_to_search = set()
+ else:
+ current_search = set(itertools.islice(next_to_search, 100))
+ next_to_search -= current_search
+
+ referenced = await self.stores.main.get_referenced_state_groups(
+ current_search
+ )
+ referenced_groups |= referenced
+
+ # We don't continue iterating up the state group graphs for state
+ # groups that are referenced.
+ current_search -= referenced
+
+ edges = await self.stores.state.get_previous_state_groups(current_search)
+
+ prevs = set(edges.values())
+ # We don't bother re-handling groups we've already seen
+ prevs -= state_groups_seen
+ next_to_search |= prevs
+ state_groups_seen |= prevs
+
+ to_delete = state_groups_seen - referenced_groups
+
+ return to_delete
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
new file mode 100644
index 0000000000..0f09953086
--- /dev/null
+++ b/synapse/storage/controllers/state.py
@@ -0,0 +1,351 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import (
+ TYPE_CHECKING,
+ Awaitable,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+)
+
+from synapse.events import EventBase
+from synapse.storage.state import StateFilter
+from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
+from synapse.types import MutableStateMap, StateMap
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+ from synapse.storage.databases import Databases
+
+logger = logging.getLogger(__name__)
+
+
+class StateGroupStorageController:
+ """High level interface to fetching state for event."""
+
+ def __init__(self, hs: "HomeServer", stores: "Databases"):
+ self._is_mine_id = hs.is_mine_id
+ self.stores = stores
+ self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
+
+ def notify_event_un_partial_stated(self, event_id: str) -> None:
+ self._partial_state_events_tracker.notify_un_partial_stated(event_id)
+
+ async def get_state_group_delta(
+ self, state_group: int
+ ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
+ """Given a state group try to return a previous group and a delta between
+ the old and the new.
+
+ Args:
+ state_group: The state group used to retrieve state deltas.
+
+ Returns:
+ A tuple of the previous group and a state map of the event IDs which
+ make up the delta between the old and new state groups.
+ """
+
+ state_group_delta = await self.stores.state.get_state_group_delta(state_group)
+ return state_group_delta.prev_group, state_group_delta.delta_ids
+
+ async def get_state_groups_ids(
+ self, _room_id: str, event_ids: Collection[str]
+ ) -> Dict[int, MutableStateMap[str]]:
+ """Get the event IDs of all the state for the state groups for the given events
+
+ Args:
+ _room_id: id of the room for these events
+ event_ids: ids of the events
+
+ Returns:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
+ """
+ if not event_ids:
+ return {}
+
+ event_to_groups = await self.get_state_group_for_events(event_ids)
+
+ groups = set(event_to_groups.values())
+ group_to_state = await self.stores.state._get_state_for_groups(groups)
+
+ return group_to_state
+
+ async def get_state_ids_for_group(
+ self, state_group: int, state_filter: Optional[StateFilter] = None
+ ) -> StateMap[str]:
+ """Get the event IDs of all the state in the given state group
+
+ Args:
+ state_group: A state group for which we want to get the state IDs.
+ state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
+
+ Returns:
+ Resolves to a map of (type, state_key) -> event_id
+ """
+ group_to_state = await self.get_state_for_groups((state_group,), state_filter)
+
+ return group_to_state[state_group]
+
+ async def get_state_groups(
+ self, room_id: str, event_ids: Collection[str]
+ ) -> Dict[int, List[EventBase]]:
+ """Get the state groups for the given list of event_ids
+
+ Args:
+ room_id: ID of the room for these events.
+ event_ids: The event IDs to retrieve state for.
+
+ Returns:
+ dict of state_group_id -> list of state events.
+ """
+ if not event_ids:
+ return {}
+
+ group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
+
+ state_event_map = await self.stores.main.get_events(
+ [
+ ev_id
+ for group_ids in group_to_ids.values()
+ for ev_id in group_ids.values()
+ ],
+ get_prev_content=False,
+ )
+
+ return {
+ group: [
+ state_event_map[v]
+ for v in event_id_map.values()
+ if v in state_event_map
+ ]
+ for group, event_id_map in group_to_ids.items()
+ }
+
+ def _get_state_groups_from_groups(
+ self, groups: List[int], state_filter: StateFilter
+ ) -> Awaitable[Dict[int, StateMap[str]]]:
+ """Returns the state groups for a given set of groups, filtering on
+ types of state events.
+
+ Args:
+ groups: list of state group IDs to query
+ state_filter: The state filter used to fetch state
+ from the database.
+
+ Returns:
+ Dict of state group to state map.
+ """
+
+ return self.stores.state._get_state_groups_from_groups(groups, state_filter)
+
+ async def get_state_for_events(
+ self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
+ ) -> Dict[str, StateMap[EventBase]]:
+ """Given a list of event_ids and type tuples, return a list of state
+ dicts for each event.
+
+ Args:
+ event_ids: The events to fetch the state of.
+ state_filter: The state filter used to fetch state.
+
+ Returns:
+ A dict of (event_id) -> (type, state_key) -> [state_events]
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
+ """
+ await_full_state = True
+ if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+ await_full_state = False
+
+ event_to_groups = await self.get_state_group_for_events(
+ event_ids, await_full_state=await_full_state
+ )
+
+ groups = set(event_to_groups.values())
+ group_to_state = await self.stores.state._get_state_for_groups(
+ groups, state_filter or StateFilter.all()
+ )
+
+ state_event_map = await self.stores.main.get_events(
+ [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
+ get_prev_content=False,
+ )
+
+ event_to_state = {
+ event_id: {
+ k: state_event_map[v]
+ for k, v in group_to_state[group].items()
+ if v in state_event_map
+ }
+ for event_id, group in event_to_groups.items()
+ }
+
+ return {event: event_to_state[event] for event in event_ids}
+
+ async def get_state_ids_for_events(
+ self,
+ event_ids: Collection[str],
+ state_filter: Optional[StateFilter] = None,
+ ) -> Dict[str, StateMap[str]]:
+ """
+ Get the state dicts corresponding to a list of events, containing the event_ids
+ of the state events (as opposed to the events themselves)
+
+ Args:
+ event_ids: events whose state should be returned
+ state_filter: The state filter used to fetch state from the database.
+
+ Returns:
+ A dict from event_id -> (type, state_key) -> event_id
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
+ """
+ await_full_state = True
+ if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+ await_full_state = False
+
+ event_to_groups = await self.get_state_group_for_events(
+ event_ids, await_full_state=await_full_state
+ )
+
+ groups = set(event_to_groups.values())
+ group_to_state = await self.stores.state._get_state_for_groups(
+ groups, state_filter or StateFilter.all()
+ )
+
+ event_to_state = {
+ event_id: group_to_state[group]
+ for event_id, group in event_to_groups.items()
+ }
+
+ return {event: event_to_state[event] for event in event_ids}
+
+ async def get_state_for_event(
+ self, event_id: str, state_filter: Optional[StateFilter] = None
+ ) -> StateMap[EventBase]:
+ """
+ Get the state dict corresponding to a particular event
+
+ Args:
+ event_id: event whose state should be returned
+ state_filter: The state filter used to fetch state from the database.
+
+ Returns:
+ A dict from (type, state_key) -> state_event
+
+ Raises:
+ RuntimeError if we don't have a state group for the event (ie it is an
+ outlier or is unknown)
+ """
+ state_map = await self.get_state_for_events(
+ [event_id], state_filter or StateFilter.all()
+ )
+ return state_map[event_id]
+
+ async def get_state_ids_for_event(
+ self, event_id: str, state_filter: Optional[StateFilter] = None
+ ) -> StateMap[str]:
+ """
+ Get the state dict corresponding to a particular event
+
+ Args:
+ event_id: event whose state should be returned
+ state_filter: The state filter used to fetch state from the database.
+
+ Returns:
+ A dict from (type, state_key) -> state_event_id
+
+ Raises:
+ RuntimeError if we don't have a state group for the event (ie it is an
+ outlier or is unknown)
+ """
+ state_map = await self.get_state_ids_for_events(
+ [event_id], state_filter or StateFilter.all()
+ )
+ return state_map[event_id]
+
+ def get_state_for_groups(
+ self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
+ ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key
+
+ Args:
+ groups: list of state groups for which we want to get the state.
+ state_filter: The state filter used to fetch state.
+ from the database.
+
+ Returns:
+ Dict of state group to state map.
+ """
+ return self.stores.state._get_state_for_groups(
+ groups, state_filter or StateFilter.all()
+ )
+
+ async def get_state_group_for_events(
+ self,
+ event_ids: Collection[str],
+ await_full_state: bool = True,
+ ) -> Mapping[str, int]:
+ """Returns mapping event_id -> state_group
+
+ Args:
+ event_ids: events to get state groups for
+ await_full_state: if true, will block if we do not yet have complete
+ state at these events.
+ """
+ if await_full_state:
+ await self._partial_state_events_tracker.await_full_state(event_ids)
+
+ return await self.stores.main._get_state_group_for_events(event_ids)
+
+ async def store_state_group(
+ self,
+ event_id: str,
+ room_id: str,
+ prev_group: Optional[int],
+ delta_ids: Optional[StateMap[str]],
+ current_state_ids: StateMap[str],
+ ) -> int:
+ """Store a new set of state, returning a newly assigned state group.
+
+ Args:
+ event_id: The event ID for which the state was calculated.
+ room_id: ID of the room for which the state was calculated.
+ prev_group: A previous state group for the room, optional.
+ delta_ids: The delta between state at `prev_group` and
+ `current_state_ids`, if `prev_group` was given. Same format as
+ `current_state_ids`.
+ current_state_ids: The state to store. Map of (type, state_key)
+ to event_id.
+
+ Returns:
+ The state group ID
+ """
+ return await self.stores.state.store_state_group(
+ event_id, room_id, prev_group, delta_ids, current_state_ids
+ )
|