summary refs log tree commit diff
path: root/synapse/events/snapshot.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/events/snapshot.py')
-rw-r--r--synapse/events/snapshot.py49
1 files changed, 49 insertions, 0 deletions
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index e0d82ad81c..a91a5d1e3c 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -23,6 +23,7 @@ from synapse.types import JsonDict, StateMap
 
 if TYPE_CHECKING:
     from synapse.storage.controllers import StorageControllers
+    from synapse.storage.databases import StateGroupDataStore
     from synapse.storage.databases.main import DataStore
     from synapse.types.state import StateFilter
 
@@ -348,6 +349,54 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
     partial_state: bool
     state_map_before_event: Optional[StateMap[str]] = None
 
+    @classmethod
+    async def batch_persist_unpersisted_contexts(
+        cls,
+        events_and_context: List[Tuple[EventBase, "UnpersistedEventContextBase"]],
+        room_id: str,
+        last_known_state_group: int,
+        datastore: "StateGroupDataStore",
+    ) -> List[Tuple[EventBase, EventContext]]:
+        """
+        Takes a list of events and their associated unpersisted contexts and persists
+        the unpersisted contexts, returning a list of events and persisted contexts.
+        Note that all the events must be in a linear chain (ie a <- b <- c).
+
+        Args:
+            events_and_context: A list of events and their unpersisted contexts
+            room_id: the room_id for the events
+            last_known_state_group: the last persisted state group
+            datastore: a state datastore
+        """
+        amended_events_and_context = await datastore.store_state_deltas_for_batched(
+            events_and_context, room_id, last_known_state_group
+        )
+
+        events_and_persisted_context = []
+        for event, unpersisted_context in amended_events_and_context:
+            if event.is_state():
+                context = EventContext(
+                    storage=unpersisted_context._storage,
+                    state_group=unpersisted_context.state_group_after_event,
+                    state_group_before_event=unpersisted_context.state_group_before_event,
+                    state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
+                    partial_state=unpersisted_context.partial_state,
+                    prev_group=unpersisted_context.state_group_before_event,
+                    delta_ids=unpersisted_context.state_delta_due_to_event,
+                )
+            else:
+                context = EventContext(
+                    storage=unpersisted_context._storage,
+                    state_group=unpersisted_context.state_group_after_event,
+                    state_group_before_event=unpersisted_context.state_group_before_event,
+                    state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
+                    partial_state=unpersisted_context.partial_state,
+                    prev_group=unpersisted_context.prev_group_for_state_group_before_event,
+                    delta_ids=unpersisted_context.delta_ids_to_state_group_before_event,
+                )
+            events_and_persisted_context.append((event, context))
+        return events_and_persisted_context
+
     async def get_prev_state_ids(
         self, state_filter: Optional["StateFilter"] = None
     ) -> StateMap[str]: