diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 609a2b88bf..afbc85ad0c 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -400,14 +400,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[StateMap[str]],
- current_state_ids: StateMap[str],
+ current_state_ids: Optional[StateMap[str]],
) -> int:
"""Store a new set of state, returning a newly assigned state group.
+ At least one of `current_state_ids` and `prev_group` must be provided. Whenever
+ `prev_group` is not None, `delta_ids` must also not be None.
+
Args:
event_id: The event ID for which the state was calculated
room_id
- prev_group: A previous state group for the room, optional.
+ prev_group: A previous state group for the room.
delta_ids: The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
@@ -418,10 +421,41 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
The state group ID
"""
- def _store_state_group_txn(txn: LoggingTransaction) -> int:
- if current_state_ids is None:
- # AFAIK, this can never happen
- raise Exception("current_state_ids cannot be None")
+ if prev_group is None and current_state_ids is None:
+ raise Exception("current_state_ids and prev_group can't both be None")
+
+ if prev_group is not None and delta_ids is None:
+ raise Exception("delta_ids is None when prev_group is not None")
+
+ def insert_delta_group_txn(
+ txn: LoggingTransaction, prev_group: int, delta_ids: StateMap[str]
+ ) -> Optional[int]:
+ """Try and persist the new group as a delta.
+
+ Requires that we have the state as a delta from a previous state group.
+
+ Returns:
+ The state group if successfully created, or None if the state
+ needs to be persisted as a full state.
+ """
+ is_in_db = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="state_groups",
+ keyvalues={"id": prev_group},
+ retcol="id",
+ allow_none=True,
+ )
+ if not is_in_db:
+ raise Exception(
+ "Trying to persist state with unpersisted prev_group: %r"
+ % (prev_group,)
+ )
+
+ # if the chain of state group deltas is going too long, we fall back to
+ # persisting a complete state group.
+ potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+ if potential_hops >= MAX_STATE_DELTA_HOPS:
+ return None
state_group = self._state_group_seq_gen.get_next_id_txn(txn)
@@ -431,51 +465,45 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
values={"id": state_group, "room_id": room_id, "event_id": event_id},
)
- # We persist as a delta if we can, while also ensuring the chain
- # of deltas isn't tooo long, as otherwise read performance degrades.
- if prev_group:
- is_in_db = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="state_groups",
- keyvalues={"id": prev_group},
- retcol="id",
- allow_none=True,
- )
- if not is_in_db:
- raise Exception(
- "Trying to persist state with unpersisted prev_group: %r"
- % (prev_group,)
- )
-
- potential_hops = self._count_state_group_hops_txn(txn, prev_group)
- if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
- assert delta_ids is not None
-
- self.db_pool.simple_insert_txn(
- txn,
- table="state_group_edges",
- values={"state_group": state_group, "prev_state_group": prev_group},
- )
+ self.db_pool.simple_insert_txn(
+ txn,
+ table="state_group_edges",
+ values={"state_group": state_group, "prev_state_group": prev_group},
+ )
- self.db_pool.simple_insert_many_txn(
- txn,
- table="state_groups_state",
- keys=("state_group", "room_id", "type", "state_key", "event_id"),
- values=[
- (state_group, room_id, key[0], key[1], state_id)
- for key, state_id in delta_ids.items()
- ],
- )
- else:
- self.db_pool.simple_insert_many_txn(
- txn,
- table="state_groups_state",
- keys=("state_group", "room_id", "type", "state_key", "event_id"),
- values=[
- (state_group, room_id, key[0], key[1], state_id)
- for key, state_id in current_state_ids.items()
- ],
- )
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ keys=("state_group", "room_id", "type", "state_key", "event_id"),
+ values=[
+ (state_group, room_id, key[0], key[1], state_id)
+ for key, state_id in delta_ids.items()
+ ],
+ )
+
+ return state_group
+
+ def insert_full_state_txn(
+ txn: LoggingTransaction, current_state_ids: StateMap[str]
+ ) -> int:
+ """Persist the full state, returning the new state group."""
+ state_group = self._state_group_seq_gen.get_next_id_txn(txn)
+
+ self.db_pool.simple_insert_txn(
+ txn,
+ table="state_groups",
+ values={"id": state_group, "room_id": room_id, "event_id": event_id},
+ )
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ keys=("state_group", "room_id", "type", "state_key", "event_id"),
+ values=[
+ (state_group, room_id, key[0], key[1], state_id)
+ for key, state_id in current_state_ids.items()
+ ],
+ )
# Prefill the state group caches with this group.
# It's fine to use the sequence like this as the state group map
@@ -491,7 +519,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
self._state_group_members_cache.update,
self._state_group_members_cache.sequence,
key=state_group,
- value=dict(current_member_state_ids),
+ value=current_member_state_ids,
)
current_non_member_state_ids = {
@@ -503,13 +531,35 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
self._state_group_cache.update,
self._state_group_cache.sequence,
key=state_group,
- value=dict(current_non_member_state_ids),
+ value=current_non_member_state_ids,
)
return state_group
+ if prev_group is not None:
+ state_group = await self.db_pool.runInteraction(
+ "store_state_group.insert_delta_group",
+ insert_delta_group_txn,
+ prev_group,
+ delta_ids,
+ )
+ if state_group is not None:
+ return state_group
+
+ # We're going to persist the state as a complete group rather than
+ # a delta, so first we need to ensure we have loaded the state map
+ # from the database.
+ if current_state_ids is None:
+ assert prev_group is not None
+ assert delta_ids is not None
+ groups = await self._get_state_for_groups([prev_group])
+ current_state_ids = dict(groups[prev_group])
+ current_state_ids.update(delta_ids)
+
return await self.db_pool.runInteraction(
- "store_state_group", _store_state_group_txn
+ "store_state_group.insert_full_state",
+ insert_full_state_txn,
+ current_state_ids,
)
async def purge_unreferenced_state_groups(
|