diff --git a/synapse/storage/controllers/purge_events.py b/synapse/storage/controllers/purge_events.py
index 2d6f80f770..47cec8c469 100644
--- a/synapse/storage/controllers/purge_events.py
+++ b/synapse/storage/controllers/purge_events.py
@@ -128,6 +128,16 @@ class PurgeEventsStorageController:
next_to_search |= prevs
state_groups_seen |= prevs
+ # We also check whether any state groups that reference the ones we
+ # are searching are themselves unreferenced. This helps ensure that
+ # we delete all unreferenced state groups; otherwise we would de-delta
+ # them when deleting the other state groups, increasing DB usage.
+ next_edges = await self.stores.state.get_next_state_groups(current_search)
+ nexts = set(next_edges.keys())
+ nexts -= state_groups_seen
+ next_to_search |= nexts
+ state_groups_seen |= nexts
+
to_delete = state_groups_seen - referenced_groups
return to_delete
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 0f47642ae5..8c7980e719 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -853,7 +853,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
List[Tuple[int, int]],
await self.db_pool.simple_select_many_batch(
table="state_group_edges",
- column="prev_state_group",
+ column="state_group",
iterable=state_groups,
keyvalues={},
retcols=("state_group", "prev_state_group"),
@@ -863,6 +863,35 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return dict(rows)
+ @trace
+ @tag_args
+ async def get_next_state_groups(
+ self, state_groups: Iterable[int]
+ ) -> Dict[int, int]:
+ """Fetch the groups that have the given state groups as their previous
+ state groups, by matching on `prev_state_group` in `state_group_edges`.
+
+ Args:
+ state_groups: The state groups to match against `prev_state_group`.
+
+ Returns:
+ A mapping from each matching state group to its previous state group.
+ """
+
+ rows = cast(
+ List[Tuple[int, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("state_group", "prev_state_group"),
+ desc="get_next_state_groups",
+ ),
+ )
+
+ return dict(rows)
+
async def purge_room_state(self, room_id: str) -> None:
return await self.db_pool.runInteraction(
"purge_room_state",
|