diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 0f47642ae5..8c7980e719 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -853,7 +853,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
List[Tuple[int, int]],
await self.db_pool.simple_select_many_batch(
table="state_group_edges",
- column="prev_state_group",
+ column="state_group",
iterable=state_groups,
keyvalues={},
retcols=("state_group", "prev_state_group"),
@@ -863,6 +863,35 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return dict(rows)
+ @trace
+ @tag_args
+ async def get_next_state_groups(
+ self, state_groups: Iterable[int]
+ ) -> Dict[int, int]:
+ """Fetch the groups that have the given state groups as their previous
+ state groups.
+
+ Args:
+ state_groups
+
+ Returns:
+ A mapping from state group to previous state group.
+ """
+
+ rows = cast(
+ List[Tuple[int, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("state_group", "prev_state_group"),
+ desc="get_next_state_groups",
+ ),
+ )
+
+ return dict(rows)
+
async def purge_room_state(self, room_id: str) -> None:
return await self.db_pool.runInteraction(
"purge_room_state",
|