diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index e0d82ad81c..a91a5d1e3c 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -23,6 +23,7 @@ from synapse.types import JsonDict, StateMap
if TYPE_CHECKING:
from synapse.storage.controllers import StorageControllers
+ from synapse.storage.databases import StateGroupDataStore
from synapse.storage.databases.main import DataStore
from synapse.types.state import StateFilter
@@ -348,6 +349,54 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
partial_state: bool
state_map_before_event: Optional[StateMap[str]] = None
+ @classmethod
+ async def batch_persist_unpersisted_contexts(
+ cls,
+ events_and_context: List[Tuple[EventBase, "UnpersistedEventContextBase"]],
+ room_id: str,
+ last_known_state_group: int,
+ datastore: "StateGroupDataStore",
+ ) -> List[Tuple[EventBase, EventContext]]:
+ """
+ Takes a list of events and their associated unpersisted contexts and persists
+ the unpersisted contexts, returning a list of events and persisted contexts.
+ Note that all the events must be in a linear chain (ie a <- b <- c).
+
+ Args:
+ events_and_context: A list of events and their unpersisted contexts
+ room_id: the room_id for the events
+ last_known_state_group: the last persisted state group
+ datastore: a state datastore
+ """
+ amended_events_and_context = await datastore.store_state_deltas_for_batched(
+ events_and_context, room_id, last_known_state_group
+ )
+
+ events_and_persisted_context = []
+ for event, unpersisted_context in amended_events_and_context:
+ if event.is_state():
+ context = EventContext(
+ storage=unpersisted_context._storage,
+ state_group=unpersisted_context.state_group_after_event,
+ state_group_before_event=unpersisted_context.state_group_before_event,
+ state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
+ partial_state=unpersisted_context.partial_state,
+ prev_group=unpersisted_context.state_group_before_event,
+ delta_ids=unpersisted_context.state_delta_due_to_event,
+ )
+ else:
+ context = EventContext(
+ storage=unpersisted_context._storage,
+ state_group=unpersisted_context.state_group_after_event,
+ state_group_before_event=unpersisted_context.state_group_before_event,
+ state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
+ partial_state=unpersisted_context.partial_state,
+ prev_group=unpersisted_context.prev_group_for_state_group_before_event,
+ delta_ids=unpersisted_context.delta_ids_to_state_group_before_event,
+ )
+ events_and_persisted_context.append((event, context))
+ return events_and_persisted_context
+
async def get_prev_state_ids(
self, state_filter: Optional["StateFilter"] = None
) -> StateMap[str]:
|