summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/state/__init__.py36
-rw-r--r--synapse/storage/controllers/state.py2
-rw-r--r--synapse/storage/databases/state/store.py156
3 files changed, 125 insertions, 69 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 56606e9afb..fcb7e829d4 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -298,12 +298,18 @@ class StateHandler:
 
             state_group_before_event_prev_group = entry.prev_group
             deltas_to_state_group_before_event = entry.delta_ids
+            state_ids_before_event = None
 
             # We make sure that we have a state group assigned to the state.
             if entry.state_group is None:
-                state_ids_before_event = await entry.get_state(
-                    self._state_storage_controller, StateFilter.all()
-                )
+                # store_state_group requires us to have either a previous state group
+                # (with deltas) or the complete state map. So, if we don't have a
+                # previous state group, load the complete state map now.
+                if state_group_before_event_prev_group is None:
+                    state_ids_before_event = await entry.get_state(
+                        self._state_storage_controller, StateFilter.all()
+                    )
+
                 state_group_before_event = (
                     await self._state_storage_controller.store_state_group(
                         event.event_id,
@@ -316,7 +322,6 @@ class StateHandler:
                 entry.state_group = state_group_before_event
             else:
                 state_group_before_event = entry.state_group
-                state_ids_before_event = None
 
         #
         # now if it's not a state event, we're done
@@ -336,19 +341,20 @@ class StateHandler:
         #
         # otherwise, we'll need to create a new state group for after the event
         #
-        if state_ids_before_event is None:
-            state_ids_before_event = await entry.get_state(
-                self._state_storage_controller, StateFilter.all()
-            )
 
         key = (event.type, event.state_key)
-        if key in state_ids_before_event:
-            replaces = state_ids_before_event[key]
-            if replaces != event.event_id:
-                event.unsigned["replaces_state"] = replaces
 
-        state_ids_after_event = dict(state_ids_before_event)
-        state_ids_after_event[key] = event.event_id
+        if state_ids_before_event is not None:
+            replaces = state_ids_before_event.get(key)
+        else:
+            replaces_state_map = await entry.get_state(
+                self._state_storage_controller, StateFilter.from_types([key])
+            )
+            replaces = replaces_state_map.get(key)
+
+        if replaces and replaces != event.event_id:
+            event.unsigned["replaces_state"] = replaces
+
         delta_ids = {key: event.event_id}
 
         state_group_after_event = (
@@ -357,7 +363,7 @@ class StateHandler:
                 event.room_id,
                 prev_group=state_group_before_event,
                 delta_ids=delta_ids,
-                current_state_ids=state_ids_after_event,
+                current_state_ids=None,
             )
         )
 
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index d3a44bc876..e08f956e6e 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -346,7 +346,7 @@ class StateStorageController:
         room_id: str,
         prev_group: Optional[int],
         delta_ids: Optional[StateMap[str]],
-        current_state_ids: StateMap[str],
+        current_state_ids: Optional[StateMap[str]],
     ) -> int:
         """Store a new set of state, returning a newly assigned state group.
 
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 609a2b88bf..afbc85ad0c 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -400,14 +400,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         room_id: str,
         prev_group: Optional[int],
         delta_ids: Optional[StateMap[str]],
-        current_state_ids: StateMap[str],
+        current_state_ids: Optional[StateMap[str]],
     ) -> int:
         """Store a new set of state, returning a newly assigned state group.
 
+        At least one of `current_state_ids` and `prev_group` must be provided. Whenever
+        `prev_group` is not None, `delta_ids` must also not be None.
+
         Args:
             event_id: The event ID for which the state was calculated
             room_id
-            prev_group: A previous state group for the room, optional.
+            prev_group: A previous state group for the room.
             delta_ids: The delta between state at `prev_group` and
                 `current_state_ids`, if `prev_group` was given. Same format as
                 `current_state_ids`.
@@ -418,10 +421,41 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             The state group ID
         """
 
-        def _store_state_group_txn(txn: LoggingTransaction) -> int:
-            if current_state_ids is None:
-                # AFAIK, this can never happen
-                raise Exception("current_state_ids cannot be None")
+        if prev_group is None and current_state_ids is None:
+            raise Exception("current_state_ids and prev_group can't both be None")
+
+        if prev_group is not None and delta_ids is None:
+            raise Exception("delta_ids is None when prev_group is not None")
+
+        def insert_delta_group_txn(
+            txn: LoggingTransaction, prev_group: int, delta_ids: StateMap[str]
+        ) -> Optional[int]:
+            """Try and persist the new group as a delta.
+
+            Requires that we have the state as a delta from a previous state group.
+
+            Returns:
+                The state group if successfully created, or None if the state
+                needs to be persisted as a full state.
+            """
+            is_in_db = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="state_groups",
+                keyvalues={"id": prev_group},
+                retcol="id",
+                allow_none=True,
+            )
+            if not is_in_db:
+                raise Exception(
+                    "Trying to persist state with unpersisted prev_group: %r"
+                    % (prev_group,)
+                )
+
+            # if the chain of state group deltas is going too long, we fall back to
+            # persisting a complete state group.
+            potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+            if potential_hops >= MAX_STATE_DELTA_HOPS:
+                return None
 
             state_group = self._state_group_seq_gen.get_next_id_txn(txn)
 
@@ -431,51 +465,45 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 values={"id": state_group, "room_id": room_id, "event_id": event_id},
             )
 
-            # We persist as a delta if we can, while also ensuring the chain
-            # of deltas isn't tooo long, as otherwise read performance degrades.
-            if prev_group:
-                is_in_db = self.db_pool.simple_select_one_onecol_txn(
-                    txn,
-                    table="state_groups",
-                    keyvalues={"id": prev_group},
-                    retcol="id",
-                    allow_none=True,
-                )
-                if not is_in_db:
-                    raise Exception(
-                        "Trying to persist state with unpersisted prev_group: %r"
-                        % (prev_group,)
-                    )
-
-                potential_hops = self._count_state_group_hops_txn(txn, prev_group)
-            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
-                assert delta_ids is not None
-
-                self.db_pool.simple_insert_txn(
-                    txn,
-                    table="state_group_edges",
-                    values={"state_group": state_group, "prev_state_group": prev_group},
-                )
+            self.db_pool.simple_insert_txn(
+                txn,
+                table="state_group_edges",
+                values={"state_group": state_group, "prev_state_group": prev_group},
+            )
 
-                self.db_pool.simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    keys=("state_group", "room_id", "type", "state_key", "event_id"),
-                    values=[
-                        (state_group, room_id, key[0], key[1], state_id)
-                        for key, state_id in delta_ids.items()
-                    ],
-                )
-            else:
-                self.db_pool.simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    keys=("state_group", "room_id", "type", "state_key", "event_id"),
-                    values=[
-                        (state_group, room_id, key[0], key[1], state_id)
-                        for key, state_id in current_state_ids.items()
-                    ],
-                )
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                keys=("state_group", "room_id", "type", "state_key", "event_id"),
+                values=[
+                    (state_group, room_id, key[0], key[1], state_id)
+                    for key, state_id in delta_ids.items()
+                ],
+            )
+
+            return state_group
+
+        def insert_full_state_txn(
+            txn: LoggingTransaction, current_state_ids: StateMap[str]
+        ) -> int:
+            """Persist the full state, returning the new state group."""
+            state_group = self._state_group_seq_gen.get_next_id_txn(txn)
+
+            self.db_pool.simple_insert_txn(
+                txn,
+                table="state_groups",
+                values={"id": state_group, "room_id": room_id, "event_id": event_id},
+            )
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                keys=("state_group", "room_id", "type", "state_key", "event_id"),
+                values=[
+                    (state_group, room_id, key[0], key[1], state_id)
+                    for key, state_id in current_state_ids.items()
+                ],
+            )
 
             # Prefill the state group caches with this group.
             # It's fine to use the sequence like this as the state group map
@@ -491,7 +519,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 self._state_group_members_cache.update,
                 self._state_group_members_cache.sequence,
                 key=state_group,
-                value=dict(current_member_state_ids),
+                value=current_member_state_ids,
             )
 
             current_non_member_state_ids = {
@@ -503,13 +531,35 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 self._state_group_cache.update,
                 self._state_group_cache.sequence,
                 key=state_group,
-                value=dict(current_non_member_state_ids),
+                value=current_non_member_state_ids,
             )
 
             return state_group
 
+        if prev_group is not None:
+            state_group = await self.db_pool.runInteraction(
+                "store_state_group.insert_delta_group",
+                insert_delta_group_txn,
+                prev_group,
+                delta_ids,
+            )
+            if state_group is not None:
+                return state_group
+
+        # We're going to persist the state as a complete group rather than
+        # a delta, so first we need to ensure we have loaded the state map
+        # from the database.
+        if current_state_ids is None:
+            assert prev_group is not None
+            assert delta_ids is not None
+            groups = await self._get_state_for_groups([prev_group])
+            current_state_ids = dict(groups[prev_group])
+            current_state_ids.update(delta_ids)
+
         return await self.db_pool.runInteraction(
-            "store_state_group", _store_state_group_txn
+            "store_state_group.insert_full_state",
+            insert_full_state_txn,
+            current_state_ids,
         )
 
     async def purge_unreferenced_state_groups(