diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index f8cfcaca83..97474bc60d 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se
import attr
from synapse.api.constants import EventTypes
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -404,6 +406,111 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
fetched_keys=non_member_types,
)
+ async def store_state_deltas_for_batched(
+ self,
+ events_and_context: List[Tuple[EventBase, EventContext]],
+ room_id: str,
+ prev_group: int,
+ ) -> List[int]:
+ """Generate and store state deltas for a group of events and contexts created to be
+ batch persisted.
+
+ Args:
+ events_and_context: the events to generate and store a state groups for
+ and their associated contexts
+ room_id: the id of the room the events were created for
+ prev_group: the state group of the last event persisted before the batched events
+ were created
+ """
+
+ def insert_deltas_group_txn(
+ txn: LoggingTransaction,
+ events_and_context: List[Tuple[EventBase, EventContext]],
+ prev_group: int,
+ ) -> List[int]:
+ """Generate and store state groups for the provided events and contexts.
+
+ Requires that we have the state as a delta from the last persisted state group.
+
+ Returns:
+ A list of state groups
+ """
+ is_in_db = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="state_groups",
+ keyvalues={"id": prev_group},
+ retcol="id",
+ allow_none=True,
+ )
+ if not is_in_db:
+ raise Exception(
+ "Trying to persist state with unpersisted prev_group: %r"
+ % (prev_group,)
+ )
+
+ num_state_groups = len(events_and_context)
+
+ state_groups = self._state_group_seq_gen.get_next_mult_txn(
+ txn, num_state_groups
+ )
+
+ index = 0
+ for event, context in events_and_context:
+ context._state_group = state_groups[index]
+ # The first prev_group will be the last persisted state group, which is passed in
+ # else it will be the group most recently assigned
+ if index > 0:
+ context.prev_group = state_groups[index - 1]
+ context.state_group_before_event = state_groups[index - 1]
+ else:
+ context.prev_group = prev_group
+ context.state_group_before_event = prev_group
+ context.delta_ids = {(event.type, event.state_key): event.event_id}
+ context._state_delta_due_to_event = {
+ (event.type, event.state_key): event.event_id
+ }
+ index += 1
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups",
+ keys=("id", "room_id", "event_id"),
+ values=[
+ (context._state_group, room_id, event.event_id)
+ for event, context in events_and_context
+ ],
+ )
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_group_edges",
+ keys=("state_group", "prev_state_group"),
+ values=[
+ (context._state_group, context.prev_group)
+ for _, context in events_and_context
+ ],
+ )
+
+ for _, context in events_and_context:
+ assert context.delta_ids is not None
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ keys=("state_group", "room_id", "type", "state_key", "event_id"),
+ values=[
+ (context._state_group, room_id, key[0], key[1], state_id)
+ for key, state_id in context.delta_ids.items()
+ ],
+ )
+ return state_groups
+
+ return await self.db_pool.runInteraction(
+ "store_state_deltas_for_batched.insert_deltas_group",
+ insert_deltas_group_txn,
+ events_and_context,
+ prev_group,
+ )
+
async def store_state_group(
self,
event_id: str,
@@ -413,10 +520,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
current_state_ids: Optional[StateMap[str]],
) -> int:
"""Store a new set of state, returning a newly assigned state group.
-
At least one of `current_state_ids` and `prev_group` must be provided. Whenever
`prev_group` is not None, `delta_ids` must also not be None.
-
Args:
event_id: The event ID for which the state was calculated
room_id
@@ -426,7 +531,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
`current_state_ids`.
current_state_ids: The state to store. Map of (type, state_key)
to event_id.
-
Returns:
The state group ID
"""
@@ -441,9 +545,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
txn: LoggingTransaction, prev_group: int, delta_ids: StateMap[str]
) -> Optional[int]:
"""Try and persist the new group as a delta.
-
Requires that we have the state as a delta from a previous state group.
-
Returns:
The state group if successfully created, or None if the state
needs to be persisted as a full state.
|