summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2022-04-21 07:42:03 +0100
committerGitHub <noreply@github.com>2022-04-21 07:42:03 +0100
commitf5668f0b4a6cca659ae98d3cb3714692ba488e89 (patch)
treecc1e5e8ff7e8190aa5d55666054bb4ba929077ab /synapse/storage/state.py
parentRemove leftover references to setup.py (#12514) (diff)
downloadsynapse-f5668f0b4a6cca659ae98d3cb3714692ba488e89.tar.xz
Await un-partial-stating after a partial-state join (#12399)
When we join a room via the faster-joins mechanism, we end up with "partial
state" at some points on the event DAG. Many parts of the codebase need to
wait for the full state to load. So, we implement a mechanism to keep track of
which events have partial state, and wait for them to be fully-populated.
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r--synapse/storage/state.py28
1 files changed, 25 insertions, 3 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index cda194e8c8..d1d5859214 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -31,6 +31,7 @@ from frozendict import frozendict
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
+from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
 from synapse.types import MutableStateMap, StateKey, StateMap
 
 if TYPE_CHECKING:
@@ -542,6 +543,10 @@ class StateGroupStorage:
 
     def __init__(self, hs: "HomeServer", stores: "Databases"):
         self.stores = stores
+        self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
+
+    def notify_event_un_partial_stated(self, event_id: str) -> None:
+        self._partial_state_events_tracker.notify_un_partial_stated(event_id)
 
     async def get_state_group_delta(
         self, state_group: int
@@ -579,7 +584,7 @@ class StateGroupStorage:
         if not event_ids:
             return {}
 
-        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
+        event_to_groups = await self._get_state_group_for_events(event_ids)
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(groups)
@@ -668,7 +673,7 @@ class StateGroupStorage:
             RuntimeError if we don't have a state group for one or more of the events
                (ie they are outliers or unknown)
         """
-        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
+        event_to_groups = await self._get_state_group_for_events(event_ids)
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
@@ -709,7 +714,7 @@ class StateGroupStorage:
             RuntimeError if we don't have a state group for one or more of the events
                 (ie they are outliers or unknown)
         """
-        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
+        event_to_groups = await self._get_state_group_for_events(event_ids)
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
@@ -785,6 +790,23 @@ class StateGroupStorage:
             groups, state_filter or StateFilter.all()
         )
 
+    async def _get_state_group_for_events(
+        self,
+        event_ids: Collection[str],
+        await_full_state: bool = True,
+    ) -> Mapping[str, int]:
+        """Returns mapping event_id -> state_group
+
+        Args:
+            event_ids: events to get state groups for
+            await_full_state: if true, will block if we do not yet have complete
+               state at this event.
+        """
+        if await_full_state:
+            await self._partial_state_events_tracker.await_full_state(event_ids)
+
+        return await self.stores.main._get_state_group_for_events(event_ids)
+
     async def store_state_group(
         self,
         event_id: str,