diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index cda194e8c8..d1d5859214 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -31,6 +31,7 @@ from frozendict import frozendict
from synapse.api.constants import EventTypes
from synapse.events import EventBase
+from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
from synapse.types import MutableStateMap, StateKey, StateMap
if TYPE_CHECKING:
@@ -542,6 +543,10 @@ class StateGroupStorage:
def __init__(self, hs: "HomeServer", stores: "Databases"):
self.stores = stores
+ self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
+
+ def notify_event_un_partial_stated(self, event_id: str) -> None:
+ self._partial_state_events_tracker.notify_un_partial_stated(event_id)
async def get_state_group_delta(
self, state_group: int
@@ -579,7 +584,7 @@ class StateGroupStorage:
if not event_ids:
return {}
- event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
+ event_to_groups = await self._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups)
@@ -668,7 +673,7 @@ class StateGroupStorage:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
- event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
+ event_to_groups = await self._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
@@ -709,7 +714,7 @@ class StateGroupStorage:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
- event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
+ event_to_groups = await self._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
@@ -785,6 +790,23 @@ class StateGroupStorage:
groups, state_filter or StateFilter.all()
)
+ async def _get_state_group_for_events(
+ self,
+ event_ids: Collection[str],
+ await_full_state: bool = True,
+ ) -> Mapping[str, int]:
+ """Returns mapping event_id -> state_group
+
+ Args:
+ event_ids: events to get state groups for
+ await_full_state: if true, will block if we do not yet have complete
+ state at this event.
+ """
+ if await_full_state:
+ await self._partial_state_events_tracker.await_full_state(event_ids)
+
+ return await self.stores.main._get_state_group_for_events(event_ids)
+
async def store_state_group(
self,
event_id: str,
|