From 7a199951202f53cef398507439bde306e4833219 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 8 Aug 2022 16:59:56 +0100 Subject: Correct a misnamed argument in state res v2 (#13467) In state res v2, we apply two passes of iterative auth checks. The first pass replays power events and events in their auth chains, but only those belonging to the full conflicted set. The source code as written suggests that we want only those belonging to the auth difference (which is a smaller set of events). At runtime we were doing the correct thing anyway, because the only callsite of `_reverse_topological_power_sort` passes in the `full_conflicted_set`. So this really is just a rename. --- synapse/state/v2.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'synapse/state') diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 7db032203b..cf3045f82e 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -434,7 +434,7 @@ async def _add_event_and_auth_chain_to_graph( event_id: str, event_map: Dict[str, EventBase], state_res_store: StateResolutionStore, - auth_diff: Set[str], + full_conflicted_set: Set[str], ) -> None: """Helper function for _reverse_topological_power_sort that add the event and its auth chain (that is in the auth diff) to the graph @@ -445,7 +445,7 @@ async def _add_event_and_auth_chain_to_graph( event_id: Event to add to the graph event_map state_res_store - auth_diff: Set of event IDs that are in the auth difference. + full_conflicted_set: Set of event IDs that are in the full conflicted set. """ state = [event_id] @@ -455,7 +455,7 @@ async def _add_event_and_auth_chain_to_graph( event = await _get_event(room_id, eid, event_map, state_res_store) for aid in event.auth_event_ids(): - if aid in auth_diff: + if aid in full_conflicted_set: if aid not in graph: state.append(aid) @@ -468,7 +468,7 @@ async def _reverse_topological_power_sort( event_ids: Iterable[str], event_map: Dict[str, EventBase], state_res_store: StateResolutionStore, - auth_diff: Set[str], + full_conflicted_set: Set[str], ) -> List[str]: """Returns a list of the event_ids sorted by reverse topological ordering, and then by power level and origin_server_ts @@ -479,7 +479,7 @@ async def _reverse_topological_power_sort( event_ids: The events to sort event_map state_res_store - auth_diff: Set of event IDs that are in the auth difference. + full_conflicted_set: Set of event IDs that are in the full conflicted set. Returns: The sorted list @@ -488,7 +488,7 @@ async def _reverse_topological_power_sort( graph: Dict[str, Set[str]] = {} for idx, event_id in enumerate(event_ids, start=1): await _add_event_and_auth_chain_to_graph( - graph, room_id, event_id, event_map, state_res_store, auth_diff + graph, room_id, event_id, event_map, state_res_store, full_conflicted_set ) # We await occasionally when we're working with large data sets to -- cgit 1.5.1 From 5e7847dc923142bc68834f9b9538ada3fdd887d5 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 23 Aug 2022 10:49:59 +0100 Subject: Cache user IDs instead of profile objects (#13573) The profile objects are never used and increase cache size significantly. --- changelog.d/13573.misc | 1 + synapse/handlers/sync.py | 4 +- synapse/state/__init__.py | 13 +++--- synapse/storage/databases/main/roommember.py | 67 ++++++++++++---------------- synapse/util/caches/descriptors.py | 26 ++++++++--- 5 files changed, 57 insertions(+), 54 deletions(-) create mode 100644 changelog.d/13573.misc (limited to 'synapse/state') diff --git a/changelog.d/13573.misc b/changelog.d/13573.misc new file mode 100644 index 0000000000..1ce9c0c081 --- /dev/null +++ b/changelog.d/13573.misc @@ -0,0 +1 @@ +Cache user IDs instead of profiles to reduce cache memory usage. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b4d3f3958c..2d95b1fa24 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -2421,10 +2421,10 @@ class SyncHandler: joined_room.room_id, joined_room.event_pos.stream ) ) - users_in_room = await self.state.get_current_users_in_room( + user_ids_in_room = await self.state.get_current_user_ids_in_room( joined_room.room_id, extrems ) - if user_id in users_in_room: + if user_id in user_ids_in_room: joined_room_ids.add(joined_room.room_id) return frozenset(joined_room_ids) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index c355e4f98a..3047e1b1ad 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -44,7 +44,6 @@ from synapse.logging.context import ContextResourceUsage from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.roommember import ProfileInfo from synapse.storage.state import StateFilter from synapse.types import StateMap from synapse.util.async_helpers import Linearizer @@ -210,11 +209,11 @@ class StateHandler: ret = await self.resolve_state_groups_for_events(room_id, event_ids) return await ret.get_state(self._state_storage_controller, state_filter) - async def get_current_users_in_room( + async def get_current_user_ids_in_room( self, room_id: str, latest_event_ids: List[str] - ) -> Dict[str, ProfileInfo]: + ) -> Set[str]: """ - Get the users who are currently in a room. + Get the users IDs who are currently in a room. Note: This is much slower than using the equivalent method `DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`, @@ -225,15 +224,15 @@ class StateHandler: room_id: The ID of the room. latest_event_ids: Precomputed list of latest event IDs. Will be computed if None. Returns: - Dictionary of user IDs to their profileinfo. + Set of user IDs in the room. """ assert latest_event_ids is not None - logger.debug("calling resolve_state_groups from get_current_users_in_room") + logger.debug("calling resolve_state_groups from get_current_user_ids_in_room") entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) state = await entry.get_state(self._state_storage_controller, StateFilter.all()) - return await self.store.get_joined_users_from_state(room_id, state, entry) + return await self.store.get_joined_user_ids_from_state(room_id, state, entry) async def get_hosts_in_room_at_events( self, room_id: str, event_ids: Collection[str] diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 827c1f1efd..0eb024a809 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -835,9 +835,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): return shared_room_ids or frozenset() - async def get_joined_users_from_state( + async def get_joined_user_ids_from_state( self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry" - ) -> Dict[str, ProfileInfo]: + ) -> Set[str]: state_group: Union[object, int] = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -848,25 +848,25 @@ class RoomMemberWorkerStore(EventsWorkerStore): assert state_group is not None with Measure(self._clock, "get_joined_users_from_state"): - return await self._get_joined_users_from_context( + return await self._get_joined_user_ids_from_context( room_id, state_group, state, context=state_entry ) @cached(num_args=2, iterable=True, max_entries=100000) - async def _get_joined_users_from_context( + async def _get_joined_user_ids_from_context( self, room_id: str, state_group: Union[object, int], current_state_ids: StateMap[str], event: Optional[EventBase] = None, context: Optional["_StateCacheEntry"] = None, - ) -> Dict[str, ProfileInfo]: + ) -> Set[str]: # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states # with a state_group of None are likely to be different. assert state_group is not None - users_in_room = {} + users_in_room = set() member_event_ids = [ e_id for key, e_id in current_state_ids.items() @@ -879,11 +879,11 @@ class RoomMemberWorkerStore(EventsWorkerStore): # If we do then we can reuse that result and simply update it with # any membership changes in `delta_ids` if context.prev_group and context.delta_ids: - prev_res = self._get_joined_users_from_context.cache.get_immediate( + prev_res = self._get_joined_user_ids_from_context.cache.get_immediate( (room_id, context.prev_group), None ) - if prev_res and isinstance(prev_res, dict): - users_in_room = dict(prev_res) + if prev_res and isinstance(prev_res, set): + users_in_room = prev_res member_event_ids = [ e_id for key, e_id in context.delta_ids.items() @@ -891,7 +891,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ] for etype, state_key in context.delta_ids: if etype == EventTypes.Member: - users_in_room.pop(state_key, None) + users_in_room.discard(state_key) # We check if we have any of the member event ids in the event cache # before we ask the DB @@ -908,42 +908,41 @@ class RoomMemberWorkerStore(EventsWorkerStore): ev_entry = event_map.get(event_id) if ev_entry and not ev_entry.event.rejected_reason: if ev_entry.event.membership == Membership.JOIN: - users_in_room[ev_entry.event.state_key] = ProfileInfo( - display_name=ev_entry.event.content.get("displayname", None), - avatar_url=ev_entry.event.content.get("avatar_url", None), - ) + users_in_room.add(ev_entry.event.state_key) else: missing_member_event_ids.append(event_id) if missing_member_event_ids: - event_to_memberships = await self._get_joined_profiles_from_event_ids( + event_to_memberships = await self._get_user_ids_from_membership_event_ids( missing_member_event_ids ) - users_in_room.update(row for row in event_to_memberships.values() if row) + users_in_room.update(event_to_memberships.values()) if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: if event.event_id in member_event_ids: - users_in_room[event.state_key] = ProfileInfo( - display_name=event.content.get("displayname", None), - avatar_url=event.content.get("avatar_url", None), - ) + users_in_room.add(event.state_key) return users_in_room - @cached(max_entries=10000) - def _get_joined_profile_from_event_id( + @cached( + max_entries=10000, + # This name matches the old function that has been replaced - the cache name + # is kept here to maintain backwards compatibility. + name="_get_joined_profile_from_event_id", + ) + def _get_user_id_from_membership_event_id( self, event_id: str ) -> Optional[Tuple[str, ProfileInfo]]: raise NotImplementedError() @cachedList( - cached_method_name="_get_joined_profile_from_event_id", + cached_method_name="_get_user_id_from_membership_event_id", list_name="event_ids", ) - async def _get_joined_profiles_from_event_ids( + async def _get_user_ids_from_membership_event_ids( self, event_ids: Iterable[str] - ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]: + ) -> Dict[str, str]: """For given set of member event_ids check if they point to a join event and if so return the associated user and profile info. @@ -958,21 +957,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): table="room_memberships", column="event_id", iterable=event_ids, - retcols=("user_id", "display_name", "avatar_url", "event_id"), + retcols=("user_id", "event_id"), keyvalues={"membership": Membership.JOIN}, batch_size=1000, - desc="_get_joined_profiles_from_event_ids", + desc="_get_user_ids_from_membership_event_ids", ) - return { - row["event_id"]: ( - row["user_id"], - ProfileInfo( - avatar_url=row["avatar_url"], display_name=row["display_name"] - ), - ) - for row in rows - } + return {row["event_id"]: row["user_id"] for row in rows} @cached(max_entries=10000) async def is_host_joined(self, room_id: str, host: str) -> bool: @@ -1131,12 +1122,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): else: # The cache doesn't match the state group or prev state group, # so we calculate the result from first principles. - joined_users = await self.get_joined_users_from_state( + joined_user_ids = await self.get_joined_user_ids_from_state( room_id, state, state_entry ) cache.hosts_to_joined_users = {} - for user_id in joined_users: + for user_id in joined_user_ids: host = intern_string(get_domain_from_id(user_id)) cache.hosts_to_joined_users.setdefault(host, set()).add(user_id) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 867f315b2a..9d4bc89edb 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -73,8 +73,10 @@ class _CacheDescriptorBase: num_args: Optional[int], uncached_args: Optional[Collection[str]] = None, cache_context: bool = False, + name: Optional[str] = None, ): self.orig = orig + self.name = name or orig.__name__ arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args @@ -211,7 +213,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: cache: LruCache[CacheKey, Any] = LruCache( - cache_name=self.orig.__name__, + cache_name=self.name, max_size=self.max_entries, ) @@ -241,7 +243,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): wrapped = cast(_CachedFunction, _wrapped) wrapped.cache = cache - obj.__dict__[self.orig.__name__] = wrapped + obj.__dict__[self.name] = wrapped return wrapped @@ -301,12 +303,14 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): cache_context: bool = False, iterable: bool = False, prune_unread_entries: bool = True, + name: Optional[str] = None, ): super().__init__( orig, num_args=num_args, uncached_args=uncached_args, cache_context=cache_context, + name=name, ) if tree and self.num_args < 2: @@ -321,7 +325,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: cache: DeferredCache[CacheKey, Any] = DeferredCache( - name=self.orig.__name__, + name=self.name, max_entries=self.max_entries, tree=self.tree, iterable=self.iterable, @@ -372,7 +376,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): wrapped.cache = cache wrapped.num_args = self.num_args - obj.__dict__[self.orig.__name__] = wrapped + obj.__dict__[self.name] = wrapped return wrapped @@ -393,6 +397,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): cached_method_name: str, list_name: str, num_args: Optional[int] = None, + name: Optional[str] = None, ): """ Args: @@ -403,7 +408,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): but including list_name) to use as cache keys. Defaults to all named args of the function. """ - super().__init__(orig, num_args=num_args, uncached_args=None) + super().__init__(orig, num_args=num_args, uncached_args=None, name=name) self.list_name = list_name @@ -525,7 +530,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): else: return defer.succeed(results) - obj.__dict__[self.orig.__name__] = wrapped + obj.__dict__[self.name] = wrapped return wrapped @@ -577,6 +582,7 @@ def cached( cache_context: bool = False, iterable: bool = False, prune_unread_entries: bool = True, + name: Optional[str] = None, ) -> Callable[[F], _CachedFunction[F]]: func = lambda orig: DeferredCacheDescriptor( orig, @@ -587,13 +593,18 @@ def cached( cache_context=cache_context, iterable=iterable, prune_unread_entries=prune_unread_entries, + name=name, ) return cast(Callable[[F], _CachedFunction[F]], func) def cachedList( - *, cached_method_name: str, list_name: str, num_args: Optional[int] = None + *, + cached_method_name: str, + list_name: str, + num_args: Optional[int] = None, + name: Optional[str] = None, ) -> Callable[[F], _CachedFunction[F]]: """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. @@ -628,6 +639,7 @@ def cachedList( cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, + name=name, ) return cast(Callable[[F], _CachedFunction[F]], func) -- cgit 1.5.1 From c406d50d2df3c04e695b826e11c79b3d6326b5ec Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 24 Aug 2022 21:06:31 +0100 Subject: Rename `event_map` to `unpersisted_events` (#13603) --- changelog.d/13603.misc | 1 + synapse/state/v2.py | 69 +++++++++++++++++++++++++++----------------------- 2 files changed, 38 insertions(+), 32 deletions(-) create mode 100644 changelog.d/13603.misc (limited to 'synapse/state') diff --git a/changelog.d/13603.misc b/changelog.d/13603.misc new file mode 100644 index 0000000000..d08eb6cc0a --- /dev/null +++ b/changelog.d/13603.misc @@ -0,0 +1 @@ +Rename `event_map` to `unpersisted_events` when computing the auth differences. diff --git a/synapse/state/v2.py b/synapse/state/v2.py index cf3045f82e..af03851c71 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -271,40 +271,41 @@ async def _get_power_level_for_sender( async def _get_auth_chain_difference( room_id: str, state_sets: Sequence[Mapping[Any, str]], - event_map: Dict[str, EventBase], + unpersisted_events: Dict[str, EventBase], state_res_store: StateResolutionStore, ) -> Set[str]: """Compare the auth chains of each state set and return the set of events - that only appear in some but not all of the auth chains. + that only appear in some, but not all of the auth chains. Args: - state_sets - event_map - state_res_store + state_sets: The input state sets we are trying to resolve across. + unpersisted_events: A map from event ID to EventBase containing all unpersisted + events involved in this resolution. + state_res_store: Returns: - Set of event IDs + The auth difference of the given state sets, as a set of event IDs. """ # The `StateResolutionStore.get_auth_chain_difference` function assumes that # all events passed to it (and their auth chains) have been persisted - # previously. This is not the case for any events in the `event_map`, and so - # we need to manually handle those events. + # previously. We need to manually handle any other events that are yet to be + # persisted. # - # We do this by: - # 1. calculating the auth chain difference for the state sets based on the - # events in `event_map` alone - # 2. replacing any events in the state_sets that are also in `event_map` - # with their auth events (recursively), and then calling - # `store.get_auth_chain_difference` as normal - # 3. adding the results of 1 and 2 together. - - # Map from event ID in `event_map` to their auth event IDs, and their auth - # event IDs if they appear in the `event_map`. This is the intersection of - # the event's auth chain with the events in the `event_map` *plus* their + # We do this in three steps: + # 1. Compute the set of unpersisted events belonging to the auth difference. + # 2. Replacing any unpersisted events in the state_sets with their auth events, + # recursively, until the state_sets contain only persisted events. + # Then we call `store.get_auth_chain_difference` as normal, which computes + # the set of persisted events belonging to the auth difference. + # 3. Adding the results of 1 and 2 together. + + # Map from event ID in `unpersisted_events` to their auth event IDs, and their auth + # event IDs if they appear in the `unpersisted_events`. This is the intersection of + # the event's auth chain with the events in `unpersisted_events` *plus* their # auth event IDs. events_to_auth_chain: Dict[str, Set[str]] = {} - for event in event_map.values(): + for event in unpersisted_events.values(): chain = {event.event_id} events_to_auth_chain[event.event_id] = chain @@ -312,16 +313,16 @@ async def _get_auth_chain_difference( while to_search: for auth_id in to_search.pop().auth_event_ids(): chain.add(auth_id) - auth_event = event_map.get(auth_id) + auth_event = unpersisted_events.get(auth_id) if auth_event: to_search.append(auth_event) - # We now a) calculate the auth chain difference for the unpersisted events - # and b) work out the state sets to pass to the store. + # We now 1) calculate the auth chain difference for the unpersisted events + # and 2) work out the state sets to pass to the store. # - # Note: If the `event_map` is empty (which is the common case), we can do a + # Note: If there are no `unpersisted_events` (which is the common case), we can do a # much simpler calculation. - if event_map: + if unpersisted_events: # The list of state sets to pass to the store, where each state set is a set # of the event ids making up the state. This is similar to `state_sets`, # except that (a) we only have event ids, not the complete @@ -344,14 +345,18 @@ async def _get_auth_chain_difference( for event_id in state_set.values(): event_chain = events_to_auth_chain.get(event_id) if event_chain is not None: - # We have an event in `event_map`. We add all the auth - # events that it references (that aren't also in `event_map`). - set_ids.update(e for e in event_chain if e not in event_map) + # We have an unpersisted event. We add all the auth + # events that it references which are also unpersisted. + set_ids.update( + e for e in event_chain if e not in unpersisted_events + ) # We also add the full chain of unpersisted event IDs # referenced by this state set, so that we can work out the # auth chain difference of the unpersisted events. - unpersisted_ids.update(e for e in event_chain if e in event_map) + unpersisted_ids.update( + e for e in event_chain if e in unpersisted_events + ) else: set_ids.add(event_id) @@ -361,15 +366,15 @@ async def _get_auth_chain_difference( union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) - difference_from_event_map: Collection[str] = union - intersection + auth_difference_unpersisted_part: Collection[str] = union - intersection else: - difference_from_event_map = () + auth_difference_unpersisted_part = () state_sets_ids = [set(state_set.values()) for state_set in state_sets] difference = await state_res_store.get_auth_chain_difference( room_id, state_sets_ids ) - difference.update(difference_from_event_map) + difference.update(auth_difference_unpersisted_part) return difference -- cgit 1.5.1 From 42b11d5565ed026c7d71f433c69e7b7007f45918 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 31 Aug 2022 12:19:39 +0100 Subject: Remove cached wrap on `_get_joined_users_from_context` method (#13569) The method doesn't actually do any data fetching and the method that does, `_get_joined_profile_from_event_id`, has its own cache. Signed off by Nick @ Beeper (@Fizzadar). --- changelog.d/13569.removal | 1 + synapse/state/__init__.py | 2 +- synapse/storage/databases/main/roommember.py | 122 +++++++++------------------ 3 files changed, 40 insertions(+), 85 deletions(-) create mode 100644 changelog.d/13569.removal (limited to 'synapse/state') diff --git a/changelog.d/13569.removal b/changelog.d/13569.removal new file mode 100644 index 0000000000..af9d407671 --- /dev/null +++ b/changelog.d/13569.removal @@ -0,0 +1 @@ +Remove redundant `_get_joined_users_from_context` cache. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 3047e1b1ad..3787d35b24 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -232,7 +232,7 @@ class StateHandler: logger.debug("calling resolve_state_groups from get_current_user_ids_in_room") entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) state = await entry.get_state(self._state_storage_controller, StateFilter.all()) - return await self.store.get_joined_user_ids_from_state(room_id, state, entry) + return await self.store.get_joined_user_ids_from_state(room_id, state) async def get_hosts_in_room_at_events( self, room_id: str, event_ids: Collection[str] diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 06500457bd..4f0adb136a 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -31,7 +31,6 @@ from typing import ( import attr from synapse.api.constants import EventTypes, Membership -from synapse.events import EventBase from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import ( run_as_background_process, @@ -883,96 +882,51 @@ class RoomMemberWorkerStore(EventsWorkerStore): return shared_room_ids or frozenset() async def get_joined_user_ids_from_state( - self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry" + self, room_id: str, state: StateMap[str] ) -> Set[str]: - state_group: Union[object, int] = state_entry.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - assert state_group is not None - with Measure(self._clock, "get_joined_users_from_state"): - return await self._get_joined_user_ids_from_context( - room_id, state_group, state, context=state_entry - ) + """ + For a given set of state IDs, get a set of user IDs in the room. - @cached(num_args=2, iterable=True, max_entries=100000) - async def _get_joined_user_ids_from_context( - self, - room_id: str, - state_group: Union[object, int], - current_state_ids: StateMap[str], - event: Optional[EventBase] = None, - context: Optional["_StateCacheEntry"] = None, - ) -> Set[str]: - # We don't use `state_group`, it's there so that we can cache based - # on it. However, it's important that it's never None, since two current_states - # with a state_group of None are likely to be different. - assert state_group is not None + This method checks the local event cache, before calling + `_get_user_ids_from_membership_event_ids` for any uncached events. + """ - users_in_room = set() - member_event_ids = [ - e_id - for key, e_id in current_state_ids.items() - if key[0] == EventTypes.Member - ] - - if context is not None: - # If we have a context with a delta from a previous state group, - # check if we also have the result from the previous group in cache. - # If we do then we can reuse that result and simply update it with - # any membership changes in `delta_ids` - if context.prev_group and context.delta_ids: - prev_res = self._get_joined_user_ids_from_context.cache.get_immediate( - (room_id, context.prev_group), None - ) - if prev_res and isinstance(prev_res, set): - users_in_room = prev_res - member_event_ids = [ - e_id - for key, e_id in context.delta_ids.items() - if key[0] == EventTypes.Member - ] - for etype, state_key in context.delta_ids: - if etype == EventTypes.Member: - users_in_room.discard(state_key) - - # We check if we have any of the member event ids in the event cache - # before we ask the DB - - # We don't update the event cache hit ratio as it completely throws off - # the hit ratio counts. After all, we don't populate the cache if we - # miss it here - event_map = self._get_events_from_local_cache( - member_event_ids, update_metrics=False - ) + with Measure(self._clock, "get_joined_user_ids_from_state"): + users_in_room = set() + member_event_ids = [ + e_id for key, e_id in state.items() if key[0] == EventTypes.Member + ] - missing_member_event_ids = [] - for event_id in member_event_ids: - ev_entry = event_map.get(event_id) - if ev_entry and not ev_entry.event.rejected_reason: - if ev_entry.event.membership == Membership.JOIN: - users_in_room.add(ev_entry.event.state_key) - else: - missing_member_event_ids.append(event_id) + # We check if we have any of the member event ids in the event cache + # before we ask the DB - if missing_member_event_ids: - event_to_memberships = await self._get_user_ids_from_membership_event_ids( - missing_member_event_ids - ) - users_in_room.update( - user_id for user_id in event_to_memberships.values() if user_id + # We don't update the event cache hit ratio as it completely throws off + # the hit ratio counts. After all, we don't populate the cache if we + # miss it here + event_map = self._get_events_from_local_cache( + member_event_ids, update_metrics=False ) - if event is not None and event.type == EventTypes.Member: - if event.membership == Membership.JOIN: - if event.event_id in member_event_ids: - users_in_room.add(event.state_key) + missing_member_event_ids = [] + for event_id in member_event_ids: + ev_entry = event_map.get(event_id) + if ev_entry and not ev_entry.event.rejected_reason: + if ev_entry.event.membership == Membership.JOIN: + users_in_room.add(ev_entry.event.state_key) + else: + missing_member_event_ids.append(event_id) + + if missing_member_event_ids: + event_to_memberships = ( + await self._get_user_ids_from_membership_event_ids( + missing_member_event_ids + ) + ) + users_in_room.update( + user_id for user_id in event_to_memberships.values() if user_id + ) - return users_in_room + return users_in_room @cached( max_entries=10000, @@ -1205,7 +1159,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # The cache doesn't match the state group or prev state group, # so we calculate the result from first principles. joined_user_ids = await self.get_joined_user_ids_from_state( - room_id, state, state_entry + room_id, state ) cache.hosts_to_joined_users = {} -- cgit 1.5.1 From b73cbb82157d9666e8d667733afebc0d09ed858c Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 16 Sep 2022 12:45:04 +0100 Subject: Avoid putting rejected events in room state (#13723) Signed-off-by: Sean Quah --- changelog.d/13723.bugfix | 1 + synapse/state/v2.py | 15 ++ tests/handlers/test_federation_event.py | 399 ++++++++++++++++++++++++++++++++ 3 files changed, 415 insertions(+) create mode 100644 changelog.d/13723.bugfix (limited to 'synapse/state') diff --git a/changelog.d/13723.bugfix b/changelog.d/13723.bugfix new file mode 100644 index 0000000000..a23174d31d --- /dev/null +++ b/changelog.d/13723.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where previously rejected events could end up in room state because they pass auth checks given the current state of the room. diff --git a/synapse/state/v2.py b/synapse/state/v2.py index af03851c71..1b9d7d8457 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -577,6 +577,21 @@ async def _iterative_auth_checks( if ev.rejected_reason is None: auth_events[key] = event_map[ev_id] + if event.rejected_reason is not None: + # Do not admit previously rejected events into state. + # TODO: This isn't spec compliant. Events that were previously rejected due + # to failing auth checks at their state, but pass auth checks during + # state resolution should be accepted. Synapse does not handle the + # change of rejection status well, so we preserve the previous + # rejection status for now. + # + # Note that events rejected for non-state reasons, such as having the + # wrong auth events, should remain rejected. + # + # https://spec.matrix.org/v1.2/rooms/v9/#rejected-events + # https://github.com/matrix-org/synapse/issues/13797 + continue + try: event_auth.check_state_dependent_auth_rules( event, diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index b5b89405a4..918010cddb 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -11,14 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from unittest import mock +from synapse.api.errors import AuthError +from synapse.api.room_versions import RoomVersion +from synapse.event_auth import ( + check_state_dependent_auth_rules, + check_state_independent_auth_rules, +) from synapse.events import make_event_from_dict from synapse.events.snapshot import EventContext from synapse.federation.transport.client import StateRequestResponse from synapse.logging.context import LoggingContext from synapse.rest import admin from synapse.rest.client import login, room +from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort +from synapse.types import JsonDict from tests import unittest from tests.test_utils import event_injection, make_awaitable @@ -449,3 +458,393 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): main_store.get_event(pulled_event.event_id, allow_none=True) ) self.assertIsNotNone(persisted, "pulled event was not persisted at all") + + def test_process_pulled_event_with_rejected_missing_state(self) -> None: + """Ensure that we correctly handle pulled events with missing state containing a + rejected state event + + In this test, we pretend we are processing a "pulled" event (eg, via backfill + or get_missing_events). The pulled event has a prev_event we haven't previously + seen, so the server requests the state at that prev_event. We expect the server + to make a /state request. + + We simulate a remote server whose /state includes a rejected kick event for a + local user. Notably, the kick event is rejected only because it cites a rejected + auth event and would otherwise be accepted based on the room state. During state + resolution, we re-run auth and can potentially introduce such rejected events + into the state if we are not careful. + + We check that the pulled event is correctly persisted, and that the state + afterwards does not include the rejected kick. + """ + # The DAG we are testing looks like: + # + # ... + # | + # v + # remote admin user joins + # | | + # +-------+ +-------+ + # | | + # | rejected power levels + # | from remote server + # | | + # | v + # | rejected kick of local user + # v from remote server + # new power levels | + # | v + # | missing event + # | from remote server + # | | + # +-------+ +-------+ + # | | + # v v + # pulled event + # from remote server + # + # (arrows are in the opposite direction to prev_events.) + + OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" + main_store = self.hs.get_datastores().main + + # Create the room. + kermit_user_id = self.register_user("kermit", "test") + kermit_tok = self.login("kermit", "test") + room_id = self.helper.create_room_as( + room_creator=kermit_user_id, tok=kermit_tok + ) + room_version = self.get_success(main_store.get_room_version(room_id)) + + # Add another local user to the room. This user is going to be kicked in a + # rejected event. + bert_user_id = self.register_user("bert", "test") + bert_tok = self.login("bert", "test") + self.helper.join(room_id, user=bert_user_id, tok=bert_tok) + + # Allow the remote user to kick bert. + # The remote user is going to send a rejected power levels event later on and we + # need state resolution to order it before another power levels event kermit is + # going to send later on. Hence we give both users the same power level, so that + # ties are broken by `origin_server_ts`. + self.helper.send_state( + room_id, + "m.room.power_levels", + {"users": {kermit_user_id: 100, OTHER_USER: 100}}, + tok=kermit_tok, + ) + + # Add the remote user to the room. + other_member_event = self.get_success( + event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join") + ) + + initial_state_map = self.get_success( + main_store.get_partial_current_state_ids(room_id) + ) + create_event = self.get_success( + main_store.get_event(initial_state_map[("m.room.create", "")]) + ) + bert_member_event = self.get_success( + main_store.get_event(initial_state_map[("m.room.member", bert_user_id)]) + ) + power_levels_event = self.get_success( + main_store.get_event(initial_state_map[("m.room.power_levels", "")]) + ) + + # We now need a rejected state event that will fail + # `check_state_independent_auth_rules` but pass + # `check_state_dependent_auth_rules`. + + # First, we create a power levels event that we pretend the remote server has + # accepted, but the local homeserver will reject. + next_depth = 100 + next_timestamp = other_member_event.origin_server_ts + 100 + rejected_power_levels_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "type": "m.room.power_levels", + "state_key": "", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [other_member_event.event_id], + "auth_events": [ + initial_state_map[("m.room.create", "")], + initial_state_map[("m.room.power_levels", "")], + # The event will be rejected because of the duplicated auth + # event. + other_member_event.event_id, + other_member_event.event_id, + ], + "origin_server_ts": next_timestamp, + "depth": next_depth, + "content": power_levels_event.content, + } + ), + room_version, + ) + next_depth += 1 + next_timestamp += 100 + + with LoggingContext("send_rejected_power_levels_event"): + self.get_success( + self.hs.get_federation_event_handler()._process_pulled_event( + self.OTHER_SERVER_NAME, + rejected_power_levels_event, + backfilled=False, + ) + ) + self.assertEqual( + self.get_success( + main_store.get_rejection_reason( + rejected_power_levels_event.event_id + ) + ), + "auth_error", + ) + + # Then we create a kick event for a local user that cites the rejected power + # levels event in its auth events. The kick event will be rejected solely + # because of the rejected auth event and would otherwise be accepted. + rejected_kick_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "type": "m.room.member", + "state_key": bert_user_id, + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [rejected_power_levels_event.event_id], + "auth_events": [ + initial_state_map[("m.room.create", "")], + rejected_power_levels_event.event_id, + initial_state_map[("m.room.member", bert_user_id)], + initial_state_map[("m.room.member", OTHER_USER)], + ], + "origin_server_ts": next_timestamp, + "depth": next_depth, + "content": {"membership": "leave"}, + } + ), + room_version, + ) + next_depth += 1 + next_timestamp += 100 + + # The kick event must fail the state-independent auth rules, but pass the + # state-dependent auth rules, so that it has a chance of making it through state + # resolution. + self.get_failure( + check_state_independent_auth_rules(main_store, rejected_kick_event), + AuthError, + ) + check_state_dependent_auth_rules( + rejected_kick_event, + [create_event, power_levels_event, other_member_event, bert_member_event], + ) + + # The kick event must also win over the original member event during state + # resolution. + self.assertEqual( + self.get_success( + _mainline_sort( + self.clock, + room_id, + event_ids=[ + bert_member_event.event_id, + rejected_kick_event.event_id, + ], + resolved_power_event_id=power_levels_event.event_id, + event_map={ + bert_member_event.event_id: bert_member_event, + rejected_kick_event.event_id: rejected_kick_event, + }, + state_res_store=main_store, + ) + ), + [bert_member_event.event_id, rejected_kick_event.event_id], + "The rejected kick event will not be applied after bert's join event " + "during state resolution. The test setup is incorrect.", + ) + + with LoggingContext("send_rejected_kick_event"): + self.get_success( + self.hs.get_federation_event_handler()._process_pulled_event( + self.OTHER_SERVER_NAME, rejected_kick_event, backfilled=False + ) + ) + self.assertEqual( + self.get_success( + main_store.get_rejection_reason(rejected_kick_event.event_id) + ), + "auth_error", + ) + + # We need another power levels event which will win over the rejected one during + # state resolution, otherwise we hit other issues where we end up with rejected + # a power levels event during state resolution. + self.reactor.advance(100) # ensure the `origin_server_ts` is larger + new_power_levels_event = self.get_success( + main_store.get_event( + self.helper.send_state( + room_id, + "m.room.power_levels", + {"users": {kermit_user_id: 100, OTHER_USER: 100, bert_user_id: 1}}, + tok=kermit_tok, + )["event_id"] + ) + ) + self.assertEqual( + self.get_success( + _reverse_topological_power_sort( + self.clock, + room_id, + event_ids=[ + new_power_levels_event.event_id, + rejected_power_levels_event.event_id, + ], + event_map={}, + state_res_store=main_store, + full_conflicted_set=set(), + ) + ), + [rejected_power_levels_event.event_id, new_power_levels_event.event_id], + "The power levels events will not have the desired ordering during state " + "resolution. The test setup is incorrect.", + ) + + # Create a missing event, so that the local homeserver has to do a `/state` or + # `/state_ids` request to pull state from the remote homeserver. + missing_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "type": "m.room.message", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [rejected_kick_event.event_id], + "auth_events": [ + initial_state_map[("m.room.create", "")], + initial_state_map[("m.room.power_levels", "")], + initial_state_map[("m.room.member", OTHER_USER)], + ], + "origin_server_ts": next_timestamp, + "depth": next_depth, + "content": {"msgtype": "m.text", "body": "foo"}, + } + ), + room_version, + ) + next_depth += 1 + next_timestamp += 100 + + # The pulled event has two prev events, one of which is missing. We will make a + # `/state` or `/state_ids` request to the remote homeserver to ask it for the + # state before the missing prev event. + pulled_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "type": "m.room.message", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [ + new_power_levels_event.event_id, + missing_event.event_id, + ], + "auth_events": [ + initial_state_map[("m.room.create", "")], + new_power_levels_event.event_id, + initial_state_map[("m.room.member", OTHER_USER)], + ], + "origin_server_ts": next_timestamp, + "depth": next_depth, + "content": {"msgtype": "m.text", "body": "bar"}, + } + ), + room_version, + ) + next_depth += 1 + next_timestamp += 100 + + # Prepare the response for the `/state` or `/state_ids` request. + # The remote server believes bert has been kicked, while the local server does + # not. + state_before_missing_event = self.get_success( + main_store.get_events_as_list(initial_state_map.values()) + ) + state_before_missing_event = [ + event + for event in state_before_missing_event + if event.event_id != bert_member_event.event_id + ] + state_before_missing_event.append(rejected_kick_event) + + # We have to bump the clock a bit, to keep the retry logic in + # `FederationClient.get_pdu` happy + self.reactor.advance(60000) + with LoggingContext("send_pulled_event"): + + async def get_event( + destination: str, event_id: str, timeout: Optional[int] = None + ) -> JsonDict: + self.assertEqual(destination, self.OTHER_SERVER_NAME) + self.assertEqual(event_id, missing_event.event_id) + return {"pdus": [missing_event.get_pdu_json()]} + + async def get_room_state_ids( + destination: str, room_id: str, event_id: str + ) -> JsonDict: + self.assertEqual(destination, self.OTHER_SERVER_NAME) + self.assertEqual(event_id, missing_event.event_id) + return { + "pdu_ids": [event.event_id for event in state_before_missing_event], + "auth_chain_ids": [], + } + + async def get_room_state( + room_version: RoomVersion, destination: str, room_id: str, event_id: str + ) -> StateRequestResponse: + self.assertEqual(destination, self.OTHER_SERVER_NAME) + self.assertEqual(event_id, missing_event.event_id) + return StateRequestResponse( + state=state_before_missing_event, + auth_events=[], + ) + + self.mock_federation_transport_client.get_event.side_effect = get_event + self.mock_federation_transport_client.get_room_state_ids.side_effect = ( + get_room_state_ids + ) + self.mock_federation_transport_client.get_room_state.side_effect = ( + get_room_state + ) + + self.get_success( + self.hs.get_federation_event_handler()._process_pulled_event( + self.OTHER_SERVER_NAME, pulled_event, backfilled=False + ) + ) + self.assertIsNone( + self.get_success( + main_store.get_rejection_reason(pulled_event.event_id) + ), + "Pulled event was unexpectedly rejected, likely due to a problem with " + "the test setup.", + ) + self.assertEqual( + {pulled_event.event_id}, + self.get_success( + main_store.have_events_in_timeline([pulled_event.event_id]) + ), + "Pulled event was not persisted, likely due to a problem with the test " + "setup.", + ) + + # We must not accept rejected events into the room state, so we expect bert + # to not be kicked, even if the remote server believes so. + new_state_map = self.get_success( + main_store.get_partial_current_state_ids(room_id) + ) + self.assertEqual( + new_state_map[("m.room.member", bert_user_id)], + bert_member_event.event_id, + "Rejected kick event unexpectedly became part of room state.", + ) -- cgit 1.5.1 From a2cf66a94d5dfd9d6496ac3e48ec9a22f17be69a Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 28 Sep 2022 02:39:03 -0700 Subject: Prepatory work for batching events to send (#13487) This PR begins work on batching up events during the creation of a room. The PR splits out the creation and sending/persisting of the events. The first three events in the creation of the room-creating the room, joining the creator to the room, and the power levels event are sent sequentially, while the subsequent events are created and collected to be sent at the end of the function. This is currently done by appending them to a list and then iterating over the list to send, the next step (after this PR) would be to send and persist the collected events as a batch. --- changelog.d/13487.misc | 1 + synapse/handlers/message.py | 175 ++++++++++++++++++++++++++-------------- synapse/handlers/room.py | 155 ++++++++++++++++++++++++----------- synapse/state/__init__.py | 63 +++++++++++++++ tests/rest/client/test_rooms.py | 4 +- 5 files changed, 290 insertions(+), 108 deletions(-) create mode 100644 changelog.d/13487.misc (limited to 'synapse/state') diff --git a/changelog.d/13487.misc b/changelog.d/13487.misc new file mode 100644 index 0000000000..761adc8b05 --- /dev/null +++ b/changelog.d/13487.misc @@ -0,0 +1 @@ +Speed up creation of DM rooms. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e07cda133a..062f93bc67 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -63,6 +63,7 @@ from synapse.types import ( MutableStateMap, Requester, RoomAlias, + StateMap, StreamToken, UserID, create_requester, @@ -567,9 +568,17 @@ class EventCreationHandler: outlier: bool = False, historical: bool = False, depth: Optional[int] = None, + state_map: Optional[StateMap[str]] = None, + for_batch: bool = False, + current_state_group: Optional[int] = None, ) -> Tuple[EventBase, EventContext]: """ - Given a dict from a client, create a new event. + Given a dict from a client, create a new event. If bool for_batch is true, will + create an event using the prev_event_ids, and will create an event context for + the event using the parameters state_map and current_state_group, thus these parameters + must be provided in this case if for_batch is True. The subsequently created event + and context are suitable for being batched up and bulk persisted to the database + with other similarly created events. Creates an FrozenEvent object, filling out auth_events, prev_events, etc. @@ -612,16 +621,27 @@ class EventCreationHandler: outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. + historical: Indicates whether the message is being inserted back in time around some existing events. This is used to skip a few checks and mark the event as backfilled. + depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. + state_map: A state map of previously created events, used only when creating events + for batch persisting + + for_batch: whether the event is being created for batch persisting to the db + + current_state_group: the current state group, used only for creating events for + batch persisting + Raises: ResourceLimitError if server is blocked to some resource being exceeded + Returns: Tuple of created event, Context """ @@ -693,6 +713,9 @@ class EventCreationHandler: auth_event_ids=auth_event_ids, state_event_ids=state_event_ids, depth=depth, + state_map=state_map, + for_batch=for_batch, + current_state_group=current_state_group, ) # In an ideal world we wouldn't need the second part of this condition. However, @@ -707,10 +730,14 @@ class EventCreationHandler: # federation as well as those created locally. As of room v3, aliases events # can be created by users that are not in the room, therefore we have to # tolerate them in event_auth.check(). - prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types([(EventTypes.Member, None)]) - ) - prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) + if for_batch: + assert state_map is not None + prev_event_id = state_map.get((EventTypes.Member, event.sender)) + else: + prev_state_ids = await context.get_prev_state_ids( + StateFilter.from_types([(EventTypes.Member, None)]) + ) + prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event = ( await self.store.get_event(prev_event_id, allow_none=True) if prev_event_id @@ -1009,8 +1036,16 @@ class EventCreationHandler: auth_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, + state_map: Optional[StateMap[str]] = None, + for_batch: bool = False, + current_state_group: Optional[int] = None, ) -> Tuple[EventBase, EventContext]: - """Create a new event for a local client + """Create a new event for a local client. If bool for_batch is true, will + create an event using the prev_event_ids, and will create an event context for + the event using the parameters state_map and current_state_group, thus these parameters + must be provided in this case if for_batch is True. The subsequently created event + and context are suitable for being batched up and bulk persisted to the database + with other similarly created events. Args: builder: @@ -1043,6 +1078,14 @@ class EventCreationHandler: Should normally be set to None, which will cause the depth to be calculated based on the prev_events. + state_map: A state map of previously created events, used only when creating events + for batch persisting + + for_batch: whether the event is being created for batch persisting to the db + + current_state_group: the current state group, used only for creating events for + batch persisting + Returns: Tuple of created event, context """ @@ -1095,64 +1138,76 @@ class EventCreationHandler: builder.type == EventTypes.Create or prev_event_ids ), "Attempting to create a non-m.room.create event with no prev_events" - event = await builder.build( - prev_event_ids=prev_event_ids, - auth_event_ids=auth_event_ids, - depth=depth, - ) + if for_batch: + assert prev_event_ids is not None + assert state_map is not None + assert current_state_group is not None + auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map) + event = await builder.build( + prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth + ) + context = await self.state.compute_event_context_for_batched( + event, state_map, current_state_group + ) + else: + event = await builder.build( + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, + depth=depth, + ) - # Pass on the outlier property from the builder to the event - # after it is created - if builder.internal_metadata.outlier: - event.internal_metadata.outlier = True - context = EventContext.for_outlier(self._storage_controllers) - elif ( - event.type == EventTypes.MSC2716_INSERTION - and state_event_ids - and builder.internal_metadata.is_historical() - ): - # Add explicit state to the insertion event so it has state to derive - # from even though it's floating with no `prev_events`. The rest of - # the batch can derive from this state and state_group. - # - # TODO(faster_joins): figure out how this works, and make sure that the - # old state is complete. - # https://github.com/matrix-org/synapse/issues/13003 - metadata = await self.store.get_metadata_for_events(state_event_ids) - - state_map_for_event: MutableStateMap[str] = {} - for state_id in state_event_ids: - data = metadata.get(state_id) - if data is None: - # We're trying to persist a new historical batch of events - # with the given state, e.g. via - # `RoomBatchSendEventRestServlet`. The state can be inferred - # by Synapse or set directly by the client. - # - # Either way, we should have persisted all the state before - # getting here. - raise Exception( - f"State event {state_id} not found in DB," - " Synapse should have persisted it before using it." - ) + # Pass on the outlier property from the builder to the event + # after it is created + if builder.internal_metadata.outlier: + event.internal_metadata.outlier = True + context = EventContext.for_outlier(self._storage_controllers) + elif ( + event.type == EventTypes.MSC2716_INSERTION + and state_event_ids + and builder.internal_metadata.is_historical() + ): + # Add explicit state to the insertion event so it has state to derive + # from even though it's floating with no `prev_events`. The rest of + # the batch can derive from this state and state_group. + # + # TODO(faster_joins): figure out how this works, and make sure that the + # old state is complete. + # https://github.com/matrix-org/synapse/issues/13003 + metadata = await self.store.get_metadata_for_events(state_event_ids) + + state_map_for_event: MutableStateMap[str] = {} + for state_id in state_event_ids: + data = metadata.get(state_id) + if data is None: + # We're trying to persist a new historical batch of events + # with the given state, e.g. via + # `RoomBatchSendEventRestServlet`. The state can be inferred + # by Synapse or set directly by the client. + # + # Either way, we should have persisted all the state before + # getting here. + raise Exception( + f"State event {state_id} not found in DB," + " Synapse should have persisted it before using it." + ) - if data.state_key is None: - raise Exception( - f"Trying to set non-state event {state_id} as state" - ) + if data.state_key is None: + raise Exception( + f"Trying to set non-state event {state_id} as state" + ) - state_map_for_event[(data.event_type, data.state_key)] = state_id + state_map_for_event[(data.event_type, data.state_key)] = state_id - context = await self.state.compute_event_context( - event, - state_ids_before_event=state_map_for_event, - # TODO(faster_joins): check how MSC2716 works and whether we can have - # partial state here - # https://github.com/matrix-org/synapse/issues/13003 - partial_state=False, - ) - else: - context = await self.state.compute_event_context(event) + context = await self.state.compute_event_context( + event, + state_ids_before_event=state_map_for_event, + # TODO(faster_joins): check how MSC2716 works and whether we can have + # partial state here + # https://github.com/matrix-org/synapse/issues/13003 + partial_state=False, + ) + else: + context = await self.state.compute_event_context(event) if requester: context.app_service = requester.app_service diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 33e9a87002..09a1a82e6c 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -716,7 +716,7 @@ class RoomCreationHandler: if ( self._server_notices_mxid is not None - and requester.user.to_string() == self._server_notices_mxid + and user_id == self._server_notices_mxid ): # allow the server notices mxid to create rooms is_requester_admin = True @@ -1042,7 +1042,9 @@ class RoomCreationHandler: creator_join_profile: Optional[JsonDict] = None, ratelimit: bool = True, ) -> Tuple[int, str, int]: - """Sends the initial events into a new room. + """Sends the initial events into a new room. Sends the room creation, membership, + and power level events into the room sequentially, then creates and batches up the + rest of the events to persist as a batch to the DB. `power_level_content_override` doesn't apply when initial state has power level state event content. @@ -1053,13 +1055,21 @@ class RoomCreationHandler: """ creator_id = creator.user.to_string() - event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} - depth = 1 + # the last event sent/persisted to the db last_sent_event_id: Optional[str] = None - - def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict: + # the most recently created event + prev_event: List[str] = [] + # a map of event types, state keys -> event_ids. We collect these mappings this as events are + # created (but not persisted to the db) to determine state for future created events + # (as this info can't be pulled from the db) + state_map: MutableStateMap[str] = {} + # current_state_group of last event created. Used for computing event context of + # events to be batched + current_state_group = None + + def create_event_dict(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict: e = {"type": etype, "content": content} e.update(event_keys) @@ -1067,32 +1077,52 @@ class RoomCreationHandler: return e - async def send(etype: str, content: JsonDict, **kwargs: Any) -> int: - nonlocal last_sent_event_id + async def create_event( + etype: str, + content: JsonDict, + for_batch: bool, + **kwargs: Any, + ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]: nonlocal depth + nonlocal prev_event - event = create(etype, content, **kwargs) - logger.debug("Sending %s in new room", etype) - # Allow these events to be sent even if the user is shadow-banned to - # allow the room creation to complete. - ( - sent_event, - last_stream_id, - ) = await self.event_creation_handler.create_and_send_nonmember_event( + event_dict = create_event_dict(etype, content, **kwargs) + + new_event, new_context = await self.event_creation_handler.create_event( creator, - event, + event_dict, + prev_event_ids=prev_event, + depth=depth, + state_map=state_map, + for_batch=for_batch, + current_state_group=current_state_group, + ) + depth += 1 + prev_event = [new_event.event_id] + state_map[(new_event.type, new_event.state_key)] = new_event.event_id + + return new_event, new_context + + async def send( + event: EventBase, + context: synapse.events.snapshot.EventContext, + creator: Requester, + ) -> int: + nonlocal last_sent_event_id + + ev = await self.event_creation_handler.handle_new_client_event( + requester=creator, + event=event, + context=context, ratelimit=False, ignore_shadow_ban=True, - # Note: we don't pass state_event_ids here because this triggers - # an additional query per event to look them up from the events table. - prev_event_ids=[last_sent_event_id] if last_sent_event_id else [], - depth=depth, ) - last_sent_event_id = sent_event.event_id - depth += 1 + last_sent_event_id = ev.event_id - return last_stream_id + # we know it was persisted, so must have a stream ordering + assert ev.internal_metadata.stream_ordering + return ev.internal_metadata.stream_ordering try: config = self._presets_dict[preset_config] @@ -1102,9 +1132,13 @@ class RoomCreationHandler: ) creation_content.update({"creator": creator_id}) - await send(etype=EventTypes.Create, content=creation_content) + creation_event, creation_context = await create_event( + EventTypes.Create, creation_content, False + ) logger.debug("Sending %s in new room", EventTypes.Member) + await send(creation_event, creation_context, creator) + # Room create event must exist at this point assert last_sent_event_id is not None member_event_id, _ = await self.room_member_handler.update_membership( @@ -1119,14 +1153,22 @@ class RoomCreationHandler: depth=depth, ) last_sent_event_id = member_event_id + prev_event = [member_event_id] + + # update the depth and state map here as the membership event has been created + # through a different code path + depth += 1 + state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - last_sent_stream_id = await send( - etype=EventTypes.PowerLevels, content=pl_content + power_event, power_context = await create_event( + EventTypes.PowerLevels, pl_content, False ) + current_state_group = power_context._state_group + last_sent_stream_id = await send(power_event, power_context, creator) else: power_level_content: JsonDict = { "users": {creator_id: 100}, @@ -1169,47 +1211,68 @@ class RoomCreationHandler: # apply those. if power_level_content_override: power_level_content.update(power_level_content_override) - - last_sent_stream_id = await send( - etype=EventTypes.PowerLevels, content=power_level_content + pl_event, pl_context = await create_event( + EventTypes.PowerLevels, + power_level_content, + False, ) + current_state_group = pl_context._state_group + last_sent_stream_id = await send(pl_event, pl_context, creator) + events_to_send = [] if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: - last_sent_stream_id = await send( - etype=EventTypes.CanonicalAlias, - content={"alias": room_alias.to_string()}, + room_alias_event, room_alias_context = await create_event( + EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True ) + current_state_group = room_alias_context._state_group + events_to_send.append((room_alias_event, room_alias_context)) if (EventTypes.JoinRules, "") not in initial_state: - last_sent_stream_id = await send( - etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} + join_rules_event, join_rules_context = await create_event( + EventTypes.JoinRules, + {"join_rule": config["join_rules"]}, + True, ) + current_state_group = join_rules_context._state_group + events_to_send.append((join_rules_event, join_rules_context)) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: - last_sent_stream_id = await send( - etype=EventTypes.RoomHistoryVisibility, - content={"history_visibility": config["history_visibility"]}, + visibility_event, visibility_context = await create_event( + EventTypes.RoomHistoryVisibility, + {"history_visibility": config["history_visibility"]}, + True, ) + current_state_group = visibility_context._state_group + events_to_send.append((visibility_event, visibility_context)) if config["guest_can_join"]: if (EventTypes.GuestAccess, "") not in initial_state: - last_sent_stream_id = await send( - etype=EventTypes.GuestAccess, - content={EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, + guest_access_event, guest_access_context = await create_event( + EventTypes.GuestAccess, + {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, + True, ) + current_state_group = guest_access_context._state_group + events_to_send.append((guest_access_event, guest_access_context)) for (etype, state_key), content in initial_state.items(): - last_sent_stream_id = await send( - etype=etype, state_key=state_key, content=content + event, context = await create_event( + etype, content, True, state_key=state_key ) + current_state_group = context._state_group + events_to_send.append((event, context)) if config["encrypted"]: - last_sent_stream_id = await send( - etype=EventTypes.RoomEncryption, + encryption_event, encryption_context = await create_event( + EventTypes.RoomEncryption, + {"algorithm": RoomEncryptionAlgorithms.DEFAULT}, + True, state_key="", - content={"algorithm": RoomEncryptionAlgorithms.DEFAULT}, ) + events_to_send.append((encryption_event, encryption_context)) + for event, context in events_to_send: + last_sent_stream_id = await send(event, context, creator) return last_sent_stream_id, last_sent_event_id, depth def _generate_room_id(self) -> str: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 3787d35b24..6f3dd0463e 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -420,6 +420,69 @@ class StateHandler: partial_state=partial_state, ) + async def compute_event_context_for_batched( + self, + event: EventBase, + state_ids_before_event: StateMap[str], + current_state_group: int, + ) -> EventContext: + """ + Generate an event context for an event that has not yet been persisted to the + database. Intended for use with events that are created to be persisted in a batch. + Args: + event: the event the context is being computed for + state_ids_before_event: a state map consisting of the state ids of the events + created prior to this event. + current_state_group: the current state group before the event. + """ + state_group_before_event_prev_group = None + deltas_to_state_group_before_event = None + + state_group_before_event = current_state_group + + # if the event is not state, we are set + if not event.is_state(): + return EventContext.with_state( + storage=self._storage_controllers, + state_group_before_event=state_group_before_event, + state_group=state_group_before_event, + state_delta_due_to_event={}, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + partial_state=False, + ) + + # otherwise, we'll need to create a new state group for after the event + key = (event.type, event.state_key) + + if state_ids_before_event is not None: + replaces = state_ids_before_event.get(key) + + if replaces and replaces != event.event_id: + event.unsigned["replaces_state"] = replaces + + delta_ids = {key: event.event_id} + + state_group_after_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event, + delta_ids=delta_ids, + current_state_ids=None, + ) + ) + + return EventContext.with_state( + storage=self._storage_controllers, + state_group=state_group_after_event, + state_group_before_event=state_group_before_event, + state_delta_due_to_event=delta_ids, + prev_group=state_group_before_event, + delta_ids=delta_ids, + partial_state=False, + ) + @measure_func() async def resolve_state_groups_for_events( self, room_id: str, event_ids: Collection[str], await_full_state: bool = True diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index c7eb88d33f..e281aef779 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -710,7 +710,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(44, channel.resource_usage.db_txn_count) + self.assertEqual(35, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -723,7 +723,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(50, channel.resource_usage.db_txn_count) + self.assertEqual(38, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id -- cgit 1.5.1 From 75888c2b1f5ec1c865c4690627bf101f7e0dffb9 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Thu, 17 Nov 2022 17:01:14 +0100 Subject: Faster joins: do not wait for full state when creating events to send (#14403) Signed-off-by: Mathieu Velten --- changelog.d/14403.misc | 1 + synapse/events/builder.py | 1 + synapse/state/__init__.py | 8 +++++++- 3 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14403.misc (limited to 'synapse/state') diff --git a/changelog.d/14403.misc b/changelog.d/14403.misc new file mode 100644 index 0000000000..ff28a2712a --- /dev/null +++ b/changelog.d/14403.misc @@ -0,0 +1 @@ +Faster joins: do not wait for full state when creating events to send. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index e2ee10dd3d..d62906043f 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -128,6 +128,7 @@ class EventBuilder: state_filter=StateFilter.from_types( auth_types_for_event(self.room_version, self) ), + await_full_state=False, ) auth_event_ids = self._event_auth_handler.compute_auth_events( self, state_ids diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 6f3dd0463e..833ffec3de 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -190,6 +190,7 @@ class StateHandler: room_id: str, event_ids: Collection[str], state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> StateMap[str]: """Fetch the state after each of the given event IDs. Resolve them and return. @@ -200,13 +201,18 @@ class StateHandler: Args: room_id: the room_id containing the given events. event_ids: the events whose state should be fetched and resolved. + await_full_state: if `True`, will block if we do not yet have complete state + at the given `event_id`s, regardless of whether `state_filter` is + satisfied by partial state. Returns: the state dict (a mapping from (event_type, state_key) -> event_id) which holds the resolution of the states after the given event IDs. """ logger.debug("calling resolve_state_groups from compute_state_after_events") - ret = await self.resolve_state_groups_for_events(room_id, event_ids) + ret = await self.resolve_state_groups_for_events( + room_id, event_ids, await_full_state + ) return await ret.get_state(self._state_storage_controller, state_filter) async def get_current_user_ids_in_room( -- cgit 1.5.1