diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/databases/main/events_worker.py | 10 | ||||
-rw-r--r-- | synapse/storage/databases/main/state.py | 1 | ||||
-rw-r--r-- | synapse/storage/state.py | 28 | ||||
-rw-r--r-- | synapse/storage/util/partial_state_events_tracker.py | 120 |
4 files changed, 155 insertions, 4 deletions
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 60876204bd..6d6e146ff1 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1974,7 +1974,15 @@ class EventsWorkerStore(SQLBaseStore): async def get_partial_state_events( self, event_ids: Collection[str] ) -> Dict[str, bool]: - """Checks which of the given events have partial state""" + """Checks which of the given events have partial state + + Args: + event_ids: the events we want to check for partial state. + + Returns: + a dict mapping from event id to partial-stateness. We return True for + any of the events which are unknown (or are outliers). + """ result = await self.db_pool.simple_select_many_batch( table="partial_state_events", column="event_id", diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 7a1b013fa3..e653841fe5 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -396,6 +396,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) # TODO(faster_joins): need to do something about workers here + txn.call_after(self.is_partial_state_event.invalidate, (event.event_id,)) txn.call_after( self._get_state_group_for_event.prefill, (event.event_id,), diff --git a/synapse/storage/state.py b/synapse/storage/state.py index cda194e8c8..d1d5859214 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -31,6 +31,7 @@ from frozendict import frozendict from synapse.api.constants import EventTypes from synapse.events import EventBase +from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker from synapse.types import MutableStateMap, StateKey, StateMap if TYPE_CHECKING: @@ -542,6 +543,10 @@ class StateGroupStorage: def __init__(self, hs: "HomeServer", stores: "Databases"): self.stores = stores + self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) + + def notify_event_un_partial_stated(self, event_id: str) -> None: + self._partial_state_events_tracker.notify_un_partial_stated(event_id) async def get_state_group_delta( self, state_group: int @@ -579,7 +584,7 @@ class StateGroupStorage: if not event_ids: return {} - event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups(groups) @@ -668,7 +673,7 @@ class StateGroupStorage: RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ - event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( @@ -709,7 +714,7 @@ class StateGroupStorage: RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ - event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( @@ -785,6 +790,23 @@ class StateGroupStorage: groups, state_filter or StateFilter.all() ) + async def _get_state_group_for_events( + self, + event_ids: Collection[str], + await_full_state: bool = True, + ) -> Mapping[str, int]: + """Returns mapping event_id -> state_group + + Args: + event_ids: events to get state groups for + await_full_state: if true, will block if we do not yet have complete + state at this event. + """ + if await_full_state: + await self._partial_state_events_tracker.await_full_state(event_ids) + + return await self.stores.main._get_state_group_for_events(event_ids) + async def store_state_group( self, event_id: str, diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py new file mode 100644 index 0000000000..a61a951ef0 --- /dev/null +++ b/synapse/storage/util/partial_state_events_tracker.py @@ -0,0 +1,120 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import defaultdict +from typing import Collection, Dict, Set + +from twisted.internet import defer +from twisted.internet.defer import Deferred + +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util import unwrapFirstError + +logger = logging.getLogger(__name__) + + +class PartialStateEventsTracker: + """Keeps track of which events have partial state, after a partial-state join""" + + def __init__(self, store: EventsWorkerStore): + self._store = store + # a map from event id to a set of Deferreds which are waiting for that event to be + # un-partial-stated. + self._observers: Dict[str, Set[Deferred[None]]] = defaultdict(set) + + def notify_un_partial_stated(self, event_id: str) -> None: + """Notify that we now have full state for a given event + + Called by the state-resynchronization loop whenever we resynchronize the state + for a particular event. Unblocks any callers to await_full_state() for that + event. + + Args: + event_id: the event that now has full state. + """ + observers = self._observers.pop(event_id, None) + if not observers: + return + logger.info( + "Notifying %i things waiting for un-partial-stating of event %s", + len(observers), + event_id, + ) + with PreserveLoggingContext(): + for o in observers: + o.callback(None) + + async def await_full_state(self, event_ids: Collection[str]) -> None: + """Wait for all the given events to have full state. + + Args: + event_ids: the list of event ids that we want full state for + """ + # first try the happy path: if there are no partial-state events, we can return + # quickly + partial_state_event_ids = [ + ev + for ev, p in (await self._store.get_partial_state_events(event_ids)).items() + if p + ] + + if not partial_state_event_ids: + return + + logger.info( + "Awaiting un-partial-stating of events %s", + partial_state_event_ids, + stack_info=True, + ) + + # create an observer for each lazy-joined event + observers: Dict[str, Deferred[None]] = { + event_id: Deferred() for event_id in partial_state_event_ids + } + for event_id, observer in observers.items(): + self._observers[event_id].add(observer) + + try: + # some of them may have been un-lazy-joined between us checking the db and + # registering the observer, in which case we'd wait forever for the + # notification. Call back the observers now. + for event_id, partial in ( + await self._store.get_partial_state_events(observers.keys()) + ).items(): + # there may have been a call to notify_un_partial_stated during the + # db query, so the observers may already have been called. + if not partial and not observers[event_id].called: + observers[event_id].callback(None) + + await make_deferred_yieldable( + defer.gatherResults( + observers.values(), + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) + logger.info("Events %s all un-partial-stated", observers.keys()) + finally: + # remove any observers we created. This should happen when the notification + # is received, but that might not happen for two reasons: + # (a) we're bailing out early on an exception (including us being + # cancelled during the await) + # (b) the event got de-lazy-joined before we set up the observer. + for event_id, observer in observers.items(): + observer_set = self._observers.get(event_id) + if observer_set: + observer_set.discard(observer) + if not observer_set: + del self._observers[event_id] |