diff --git a/changelog.d/18131.bugfix b/changelog.d/18131.bugfix
new file mode 100644
index 0000000000..4d0c19fab9
--- /dev/null
+++ b/changelog.d/18131.bugfix
@@ -0,0 +1 @@
+Fix rare edge case where state groups could be deleted while we are persisting new events that reference them.
diff --git a/synapse/storage/controllers/purge_events.py b/synapse/storage/controllers/purge_events.py
index 2d6f80f770..47cec8c469 100644
--- a/synapse/storage/controllers/purge_events.py
+++ b/synapse/storage/controllers/purge_events.py
@@ -128,6 +128,16 @@ class PurgeEventsStorageController:
next_to_search |= prevs
state_groups_seen |= prevs
+ # We also check whether anything referencing these state groups is
+ # itself unreferenced. This helps ensure that we delete unreferenced
+ # state groups; if we don't, they get de-deltaed when we delete the
+ # other state groups, leading to increased DB usage.
+ next_edges = await self.stores.state.get_next_state_groups(current_search)
+ nexts = set(next_edges.keys())
+ nexts -= state_groups_seen
+ next_to_search |= nexts
+ state_groups_seen |= nexts
+
to_delete = state_groups_seen - referenced_groups
return to_delete
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 0f47642ae5..8c7980e719 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -853,7 +853,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
List[Tuple[int, int]],
await self.db_pool.simple_select_many_batch(
table="state_group_edges",
- column="prev_state_group",
+ column="state_group",
iterable=state_groups,
keyvalues={},
retcols=("state_group", "prev_state_group"),
@@ -863,6 +863,35 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return dict(rows)
+ @trace
+ @tag_args
+ async def get_next_state_groups(
+ self, state_groups: Iterable[int]
+ ) -> Dict[int, int]:
+ """Fetch the groups that have the given state groups as their previous
+ state groups.
+
+ Args:
+ state_groups: The groups to match against the `prev_state_group` column of `state_group_edges`.
+
+ Returns:
+ A mapping from each matching ("next") state group to its previous state group, i.e. one of the given `state_groups`.
+ """
+
+ rows = cast(
+ List[Tuple[int, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("state_group", "prev_state_group"),
+ desc="get_next_state_groups",
+ ),
+ )
+
+ return dict(rows)
+
async def purge_room_state(self, room_id: str) -> None:
return await self.db_pool.runInteraction(
"purge_room_state",
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index efd8d25bd1..5d6a8518c0 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -195,3 +195,78 @@ class PurgeTests(HomeserverTestCase):
self.assertEqual(second_state, {})
self.assertEqual(third_state, {})
self.assertNotEqual(last_state, {})
+
+ def test_purge_unreferenced_state_group(self) -> None:
+ """Test that purging room history also gets rid of unreferenced state groups
+ it encounters during the purge.
+
+ This is important, as otherwise these unreferenced state groups get
+ "de-deltaed" during the purge process, consuming lots of disk space.
+ """
+
+ self.helper.send(self.room_id, body="test1")
+ state1 = self.helper.send_state(
+ self.room_id, "org.matrix.test", body={"number": 2}
+ )
+ state2 = self.helper.send_state(
+ self.room_id, "org.matrix.test", body={"number": 3}
+ )
+ self.helper.send(self.room_id, body="test4")
+ last = self.helper.send(self.room_id, body="test5")
+
+ # Create an unreferenced state group whose prev group is the state group
+ # of one of the to-be-purged events (`state1`).
+ prev_group = self.get_success(
+ self.store._get_state_group_for_event(state1["event_id"])
+ )
+ unreferenced_state_group = self.get_success(
+ self.state_store.store_state_group(
+ event_id=last["event_id"],
+ room_id=self.room_id,
+ prev_group=prev_group,
+ delta_ids={("org.matrix.test", ""): state2["event_id"]},
+ current_state_ids=None,
+ )
+ )
+
+ # Get the topological token for the last event
+ token = self.get_success(
+ self.store.get_topological_token_for_event(last["event_id"])
+ )
+ token_str = self.get_success(token.to_string(self.hs.get_datastores().main))
+
+ # Purge everything before this topological token
+ self.get_success(
+ self._storage_controllers.purge_events.purge_history(
+ self.room_id, token_str, True
+ )
+ )
+
+ # Advance past the deletion delay so that the background jobs that delete the state groups run
+ self.reactor.advance(
+ 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
+ )
+
+ # We expect that the unreferenced state group has been deleted.
+ row = self.get_success(
+ self.state_store.db_pool.simple_select_one_onecol(
+ table="state_groups",
+ keyvalues={"id": unreferenced_state_group},
+ retcol="id",
+ allow_none=True,
+ desc="test_purge_unreferenced_state_group",
+ )
+ )
+ self.assertIsNone(row)
+
+ # We expect there to now only be one state group for the room, which is
+ # the state group of the last event (as the only outlier).
+ state_groups = self.get_success(
+ self.state_store.db_pool.simple_select_onecol(
+ table="state_groups",
+ keyvalues={"room_id": self.room_id},
+ retcol="id",
+ desc="test_purge_unreferenced_state_group",
+ )
+ )
+ self.assertEqual(len(state_groups), 1)
|