author      Richard van der Hoff <1389908+richvdh@users.noreply.github.com>  2022-04-21 07:42:03 +0100
committer   GitHub <noreply@github.com>  2022-04-21 07:42:03 +0100
commit      f5668f0b4a6cca659ae98d3cb3714692ba488e89
tree        cc1e5e8ff7e8190aa5d55666054bb4ba929077ab /synapse/storage
parent      Remove leftover references to setup.py (#12514)
download    synapse-f5668f0b4a6cca659ae98d3cb3714692ba488e89.tar.xz
Await un-partial-stating after a partial-state join (#12399)
When we join a room via the faster-joins mechanism, we end up with "partial
state" at some points on the event DAG. Many parts of the codebase need to
wait for the full state to load. So, we implement a mechanism to keep track of
which events have partial state, and wait for them to be fully-populated.
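
For orientation (this snippet is not part of the commit; the store variable and the event ID are invented, only the tracker API below is real), the intended flow is roughly:

    tracker = PartialStateEventsTracker(store)  # one per StateGroupStorage

    # a caller that needs full state at some events blocks here...
    await tracker.await_full_state(["$event_with_partial_state"])

    # ...until the state-resynchronisation loop, having filled the state in,
    # wakes it up again:
    tracker.notify_un_partial_stated("$event_with_partial_state")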
Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/databases/main/events_worker.py       |  10
-rw-r--r--  synapse/storage/databases/main/state.py               |   1
-rw-r--r--  synapse/storage/state.py                              |  28
-rw-r--r--  synapse/storage/util/partial_state_events_tracker.py  | 120
4 files changed, 155 insertions, 4 deletions
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 60876204bd..6d6e146ff1 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1974,7 +1974,15 @@ class EventsWorkerStore(SQLBaseStore):
     async def get_partial_state_events(
         self, event_ids: Collection[str]
     ) -> Dict[str, bool]:
-        """Checks which of the given events have partial state"""
+        """Checks which of the given events have partial state
+
+        Args:
+            event_ids: the events we want to check for partial state.
+
+        Returns:
+            a dict mapping from event id to partial-stateness. We return True for
+            any of the events which are unknown (or are outliers).
+        """
         result = await self.db_pool.simple_select_many_batch(
             table="partial_state_events",
             column="event_id",
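
For illustration only (these event IDs are invented and the snippet is not part of the commit), the documented contract means a call such as the following resolves to False for events whose state is fully known, and True both for partial-state events and for events we know nothing about:

    result = await store.get_partial_state_events(
        ["$fully_stated", "$partially_stated", "$never_seen"]
    )
    # result == {
    #     "$fully_stated": False,
    #     "$partially_stated": True,
    #     "$never_seen": True,
    # }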
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 7a1b013fa3..e653841fe5 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -396,6 +396,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
         # TODO(faster_joins): need to do something about workers here
+        txn.call_after(self.is_partial_state_event.invalidate, (event.event_id,))
         txn.call_after(
             self._get_state_group_for_event.prefill,
             (event.event_id,),
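
The new invalidation is registered via txn.call_after rather than performed immediately: in Synapse's database layer those callbacks only run once the transaction has completed successfully, so a rolled-back transaction cannot leave the is_partial_state_event cache out of date. A toy sketch of that idiom (not Synapse's real LoggingTransaction):

    class ToyTransaction:
        """Minimal stand-in: queue side effects until the transaction commits."""

        def __init__(self):
            self._after_callbacks = []

        def call_after(self, fn, *args):
            # remember the callback, but do not run it yet
            self._after_callbacks.append((fn, args))

        def commit(self):
            # only now is it safe to invalidate/prefill caches
            for fn, args in self._after_callbacks:
                fn(*args)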
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index cda194e8c8..d1d5859214 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -31,6 +31,7 @@ from frozendict import frozendict
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
+from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
 from synapse.types import MutableStateMap, StateKey, StateMap
 
 if TYPE_CHECKING:
@@ -542,6 +543,10 @@ class StateGroupStorage:
 
     def __init__(self, hs: "HomeServer", stores: "Databases"):
         self.stores = stores
+        self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
+
+    def notify_event_un_partial_stated(self, event_id: str) -> None:
+        self._partial_state_events_tracker.notify_un_partial_stated(event_id)
 
     async def get_state_group_delta(
         self, state_group: int
@@ -579,7 +584,7 @@ class StateGroupStorage:
         if not event_ids:
             return {}
 
-        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
+        event_to_groups = await self._get_state_group_for_events(event_ids)
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(groups)
@@ -668,7 +673,7 @@ class StateGroupStorage:
             RuntimeError if we don't have a state group for one or more of the events
                (ie they are outliers or unknown)
         """
-        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
+        event_to_groups = await self._get_state_group_for_events(event_ids)
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
@@ -709,7 +714,7 @@ class StateGroupStorage:
             RuntimeError if we don't have a state group for one or more of the events
                 (ie they are outliers or unknown)
         """
-        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
+        event_to_groups = await self._get_state_group_for_events(event_ids)
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
@@ -785,6 +790,23 @@ class StateGroupStorage:
             groups, state_filter or StateFilter.all()
         )
 
+    async def _get_state_group_for_events(
+        self,
+        event_ids: Collection[str],
+        await_full_state: bool = True,
+    ) -> Mapping[str, int]:
+        """Returns mapping event_id -> state_group
+
+        Args:
+            event_ids: events to get state groups for
+            await_full_state: if true, will block if we do not yet have complete
+               state at this event.
+        """
+        if await_full_state:
+            await self._partial_state_events_tracker.await_full_state(event_ids)
+
+        return await self.stores.main._get_state_group_for_events(event_ids)
+
     async def store_state_group(
         self,
         event_id: str,
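
Note that await_full_state defaults to True, so the three call sites updated above all block behind the tracker. A hypothetical internal caller that could tolerate partial state (not something this commit adds) would have to opt out explicitly:

    # hypothetical caller that is happy to work with partial state:
    event_to_groups = await storage.state._get_state_group_for_events(
        event_ids, await_full_state=False
    )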
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
new file mode 100644
index 0000000000..a61a951ef0
--- /dev/null
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -0,0 +1,120 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from collections import defaultdict
+from typing import Collection, Dict, Set
+
+from twisted.internet import defer
+from twisted.internet.defer import Deferred
+
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util import unwrapFirstError
+
+logger = logging.getLogger(__name__)
+
+
+class PartialStateEventsTracker:
+    """Keeps track of which events have partial state, after a partial-state join"""
+
+    def __init__(self, store: EventsWorkerStore):
+        self._store = store
+        # a map from event id to a set of Deferreds which are waiting for that event to be
+        # un-partial-stated.
+        self._observers: Dict[str, Set[Deferred[None]]] = defaultdict(set)
+
+    def notify_un_partial_stated(self, event_id: str) -> None:
+        """Notify that we now have full state for a given event
+
+        Called by the state-resynchronization loop whenever we resynchronize the state
+        for a particular event. Unblocks any callers to await_full_state() for that
+        event.
+
+        Args:
+            event_id: the event that now has full state.
+        """
+        observers = self._observers.pop(event_id, None)
+        if not observers:
+            return
+        logger.info(
+            "Notifying %i things waiting for un-partial-stating of event %s",
+            len(observers),
+            event_id,
+        )
+        with PreserveLoggingContext():
+            for o in observers:
+                o.callback(None)
+
+    async def await_full_state(self, event_ids: Collection[str]) -> None:
+        """Wait for all the given events to have full state.
+
+        Args:
+            event_ids: the list of event ids that we want full state for
+        """
+        # first try the happy path: if there are no partial-state events, we can return
+        # quickly
+        partial_state_event_ids = [
+            ev
+            for ev, p in (await self._store.get_partial_state_events(event_ids)).items()
+            if p
+        ]
+
+        if not partial_state_event_ids:
+            return
+
+        logger.info(
+            "Awaiting un-partial-stating of events %s",
+            partial_state_event_ids,
+            stack_info=True,
+        )
+
+        # create an observer for each lazy-joined event
+        observers: Dict[str, Deferred[None]] = {
+            event_id: Deferred() for event_id in partial_state_event_ids
+        }
+        for event_id, observer in observers.items():
+            self._observers[event_id].add(observer)
+
+        try:
+            # some of them may have been un-lazy-joined between us checking the db and
+            # registering the observer, in which case we'd wait forever for the
+            # notification. Call back the observers now.
+            for event_id, partial in (
+                await self._store.get_partial_state_events(observers.keys())
+            ).items():
+                # there may have been a call to notify_un_partial_stated during the
+                # db query, so the observers may already have been called.
+                if not partial and not observers[event_id].called:
+                    observers[event_id].callback(None)
+
+            await make_deferred_yieldable(
+                defer.gatherResults(
+                    observers.values(),
+                    consumeErrors=True,
+                )
+            ).addErrback(unwrapFirstError)
+            logger.info("Events %s all un-partial-stated", observers.keys())
+        finally:
+            # remove any observers we created. This should happen when the notification
+            # is received, but that might not happen for two reasons:
+            #   (a) we're bailing out early on an exception (including us being
+            #       cancelled during the await)
+            #   (b) the event got de-lazy-joined before we set up the observer.
+            for event_id, observer in observers.items():
+                observer_set = self._observers.get(event_id)
+                if observer_set:
+                    observer_set.discard(observer)
+                    if not observer_set:
+                        del self._observers[event_id]
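
Finally, a self-contained example of driving the new tracker in a unit-test style (the stub store and the event ID are invented; only PartialStateEventsTracker's API comes from this commit): await_full_state stays pending while the store reports partial state, and completes once notify_un_partial_stated is called.

    from typing import Dict
    from unittest import mock

    from twisted.internet.defer import ensureDeferred
    from twisted.trial import unittest

    from synapse.storage.util.partial_state_events_tracker import (
        PartialStateEventsTracker,
    )


    class PartialStateEventsTrackerExample(unittest.TestCase):
        def setUp(self) -> None:
            # stub EventsWorkerStore: "$event" starts out with partial state
            self._partial: Dict[str, bool] = {"$event": True}

            async def get_partial_state_events(event_ids):
                return {e: self._partial[e] for e in event_ids}

            self.store = mock.Mock(get_partial_state_events=get_partial_state_events)
            self.tracker = PartialStateEventsTracker(self.store)

        def test_blocks_until_notified(self) -> None:
            d = ensureDeferred(self.tracker.await_full_state(["$event"]))
            self.assertFalse(d.called)  # still waiting on "$event"

            # the state-resynchronisation loop completes and notifies the tracker
            self._partial["$event"] = False
            self.tracker.notify_un_partial_stated("$event")
            self.successResultOf(d)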