From 070c0279d43b5bbebc640097aec5a48d87d2dec4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 19 May 2022 16:18:30 +0100 Subject: await_lazy_loading --- synapse/events/builder.py | 4 +++- synapse/handlers/sync.py | 8 ++++++-- synapse/state/__init__.py | 35 +++++++++++++++++++++++++++------ synapse/storage/state.py | 49 +++++++++++++++++++++++++++++++++++++---------- synapse/visibility.py | 1 + 5 files changed, 78 insertions(+), 19 deletions(-) (limited to 'synapse') diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 98c203ada0..68ab2113d4 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -120,8 +120,10 @@ class EventBuilder: The signed and hashed event. """ if auth_event_ids is None: + # we pick the auth events based on our best knowledge of the current state + # of the room, so we don't need to await full state. state_ids = await self._state.get_current_state_ids( - self.room_id, prev_event_ids + self.room_id, prev_event_ids, await_full_state=False ) auth_event_ids = self._event_auth_handler.compute_auth_events( self, state_ids diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 59b5d497be..ed0e8a9fe6 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -902,11 +902,15 @@ class SyncHandler: if full_state: if batch: current_state_ids = await self.state_store.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter + batch.events[-1].event_id, + state_filter=state_filter, + await_full_state=not lazy_load_members, # TODO ) state_ids = await self.state_store.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter + batch.events[0].event_id, + state_filter=state_filter, + await_full_state=not lazy_load_members, # TODO ) else: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 4b4ed42cff..098b5f32ff 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -48,6 +48,7 @@ from synapse.logging.context import ContextResourceUsage from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo +from synapse.storage.state import StateFilter from synapse.types import StateMap from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -177,7 +178,16 @@ class StateHandler: assert latest_event_ids is not None logger.debug("calling resolve_state_groups from get_current_state") - ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) + + filter = StateFilter.all() + if event_type: + filter = StateFilter.from_types(((event_type, state_key),)) + + ret = await self.resolve_state_groups_for_events( + room_id, + latest_event_ids, + await_full_state=filter.must_await_full_state(self.hs.is_mine_id), + ) state = ret.state if event_type: @@ -195,7 +205,10 @@ class StateHandler: } async def get_current_state_ids( - self, room_id: str, latest_event_ids: Optional[Collection[str]] = None + self, + room_id: str, + latest_event_ids: Optional[Collection[str]] = None, + await_full_state: bool = True, ) -> StateMap[str]: """Get the current state, or the state at a set of events, for a room @@ -203,6 +216,8 @@ class StateHandler: room_id: latest_event_ids: if given, the forward extremities to resolve. If None, we look them up from the database (via a cache). + await_full_state: if true, will block if we do not yet have complete + state at the latest events. Returns: the state dict, mapping from (event_type, state_key) -> event_id @@ -212,7 +227,9 @@ class StateHandler: assert latest_event_ids is not None logger.debug("calling resolve_state_groups from get_current_state_ids") - ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) + ret = await self.resolve_state_groups_for_events( + room_id, latest_event_ids, await_full_state=await_full_state + ) return ret.state async def get_current_users_in_room( @@ -323,7 +340,9 @@ class StateHandler: logger.debug("calling resolve_state_groups from compute_event_context") entry = await self.resolve_state_groups_for_events( - event.room_id, event.prev_event_ids() + event.room_id, + event.prev_event_ids(), + await_full_state=False, ) state_ids_before_event = entry.state @@ -404,7 +423,7 @@ class StateHandler: @measure_func() async def resolve_state_groups_for_events( - self, room_id: str, event_ids: Collection[str] + self, room_id: str, event_ids: Collection[str], await_full_state: bool = True ) -> _StateCacheEntry: """Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. @@ -412,13 +431,17 @@ class StateHandler: Args: room_id event_ids + await_full_state: if true, will block if we do not yet have complete + state at these events. Returns: The resolved state """ logger.debug("resolve_state_groups event_ids %s", event_ids) - state_groups = await self.state_store.get_state_group_for_events(event_ids) + state_groups = await self.state_store.get_state_group_for_events( + event_ids, await_full_state=await_full_state + ) state_group_ids = state_groups.values() diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e58301a8f0..a7e721aef2 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -609,13 +609,18 @@ class StateGroupStorage: return state_group_delta.prev_group, state_group_delta.delta_ids async def get_state_groups_ids( - self, _room_id: str, event_ids: Collection[str] + self, + _room_id: str, + event_ids: Collection[str], + await_full_state: bool = True, ) -> Dict[int, MutableStateMap[str]]: """Get the event IDs of all the state for the state groups for the given events Args: _room_id: id of the room for these events event_ids: ids of the events + await_full_state: if true, will block if we do not yet have complete + state at these events. Returns: dict of state_group_id -> (dict of (type, state_key) -> event id) @@ -627,7 +632,9 @@ class StateGroupStorage: if not event_ids: return {} - event_to_groups = await self.get_state_group_for_events(event_ids) + event_to_groups = await self.get_state_group_for_events( + event_ids, await_full_state=await_full_state + ) groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups(groups) @@ -700,7 +707,10 @@ class StateGroupStorage: return self.stores.state._get_state_groups_from_groups(groups, state_filter) async def get_state_for_events( - self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None + self, + event_ids: Collection[str], + state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> Dict[str, StateMap[EventBase]]: """Given a list of event_ids and type tuples, return a list of state dicts for each event. @@ -708,6 +718,8 @@ class StateGroupStorage: Args: event_ids: The events to fetch the state of. state_filter: The state filter used to fetch state. + await_full_state: if true, will block if the state_filter includes state + which is not yet complete. Returns: A dict of (event_id) -> (type, state_key) -> [state_events] @@ -716,8 +728,11 @@ class StateGroupStorage: RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ - await_full_state = True - if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + if ( + await_full_state + and state_filter + and not state_filter.must_await_full_state(self._is_mine_id) + ): await_full_state = False event_to_groups = await self.get_state_group_for_events( @@ -749,6 +764,7 @@ class StateGroupStorage: self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> Dict[str, StateMap[str]]: """ Get the state dicts corresponding to a list of events, containing the event_ids @@ -757,6 +773,8 @@ class StateGroupStorage: Args: event_ids: events whose state should be returned state_filter: The state filter used to fetch state from the database. + await_full_state: if true, will block if the state_filter includes state + which is not yet complete. Returns: A dict from event_id -> (type, state_key) -> event_id @@ -765,8 +783,12 @@ class StateGroupStorage: RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ - await_full_state = True - if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + + if ( + await_full_state + and state_filter + and not state_filter.must_await_full_state(self._is_mine_id) + ): await_full_state = False event_to_groups = await self.get_state_group_for_events( @@ -808,7 +830,10 @@ class StateGroupStorage: return state_map[event_id] async def get_state_ids_for_event( - self, event_id: str, state_filter: Optional[StateFilter] = None + self, + event_id: str, + state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> StateMap[str]: """ Get the state dict corresponding to a particular event @@ -816,6 +841,8 @@ class StateGroupStorage: Args: event_id: event whose state should be returned state_filter: The state filter used to fetch state from the database. + await_full_state: if true, will block if the state_filter includes state + which is not yet complete. Returns: A dict from (type, state_key) -> state_event_id @@ -825,7 +852,9 @@ class StateGroupStorage: outlier or is unknown) """ state_map = await self.get_state_ids_for_events( - [event_id], state_filter or StateFilter.all() + [event_id], + state_filter or StateFilter.all(), + await_full_state=await_full_state, ) return state_map[event_id] @@ -857,7 +886,7 @@ class StateGroupStorage: Args: event_ids: events to get state groups for await_full_state: if true, will block if we do not yet have complete - state at these events. + state at these event. """ if await_full_state: await self._partial_state_events_tracker.await_full_state(event_ids) diff --git a/synapse/visibility.py b/synapse/visibility.py index de6d2ffc52..b851e660fd 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -85,6 +85,7 @@ async def filter_events_for_client( event_id_to_state = await storage.state.get_state_for_events( frozenset(e.event_id for e in events if not e.internal_metadata.outlier), state_filter=StateFilter.from_types(types), + await_full_state=False, ) # Get the users who are ignored by the requesting user. -- cgit 1.4.1