diff options
Diffstat (limited to 'synapse/storage/databases/main/roommember.py')
-rw-r--r-- | synapse/storage/databases/main/roommember.py | 161 |
1 files changed, 89 insertions, 72 deletions
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index bd8513cd43..2a8532f8c1 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -23,8 +23,11 @@ from typing import ( Optional, Set, Tuple, + Union, ) +import attr + from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase from synapse.events.snapshot import EventContext @@ -43,7 +46,7 @@ from synapse.storage.roommember import ( ProfileInfo, RoomsForUser, ) -from synapse.types import PersistedEventPosition, get_domain_from_id +from synapse.types import PersistedEventPosition, StateMap, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -63,6 +66,10 @@ class RoomMemberWorkerStore(EventsWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) + # Used by `_get_joined_hosts` to ensure only one thing mutates the cache + # at a time. Keyed by room_id. + self._joined_host_linearizer = Linearizer("_JoinedHostsCache") + # Is the current_state_events.membership up to date? Or is the # background update still running? self._current_state_events_membership_up_to_date = False @@ -740,19 +747,82 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(num_args=2, max_entries=10000, iterable=True) async def _get_joined_hosts( - self, room_id, state_group, current_state_ids, state_entry - ): - # We don't use `state_group`, its there so that we can cache based - # on it. However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. + self, + room_id: str, + state_group: int, + current_state_ids: StateMap[str], + state_entry: "_StateCacheEntry", + ) -> FrozenSet[str]: + # We don't use `state_group`, its there so that we can cache based on + # it. However, its important that its never None, since two + # current_state's with a state_group of None are likely to be different. + # + # The `state_group` must match the `state_entry.state_group` (if not None). assert state_group is not None - + assert state_entry.state_group is None or state_entry.state_group == state_group + + # We use a secondary cache of previous work to allow us to build up the + # joined hosts for the given state group based on previous state groups. + # + # We cache one object per room containing the results of the last state + # group we got joined hosts for. The idea is that generally + # `get_joined_hosts` is called with the "current" state group for the + # room, and so consecutive calls will be for consecutive state groups + # which point to the previous state group. cache = await self._get_joined_hosts_cache(room_id) - return await cache.get_destinations(state_entry) + + # If the state group in the cache matches, we already have the data we need. + if state_entry.state_group == cache.state_group: + return frozenset(cache.hosts_to_joined_users) + + # Since we'll mutate the cache we need to lock. + with (await self._joined_host_linearizer.queue(room_id)): + if state_entry.state_group == cache.state_group: + # Same state group, so nothing to do. We've already checked for + # this above, but the cache may have changed while waiting on + # the lock. + pass + elif state_entry.prev_group == cache.state_group: + # The cached work is for the previous state group, so we work out + # the delta. + for (typ, state_key), event_id in state_entry.delta_ids.items(): + if typ != EventTypes.Member: + continue + + host = intern_string(get_domain_from_id(state_key)) + user_id = state_key + known_joins = cache.hosts_to_joined_users.setdefault(host, set()) + + event = await self.get_event(event_id) + if event.membership == Membership.JOIN: + known_joins.add(user_id) + else: + known_joins.discard(user_id) + + if not known_joins: + cache.hosts_to_joined_users.pop(host, None) + else: + # The cache doesn't match the state group or prev state group, + # so we calculate the result from first principles. + joined_users = await self.get_joined_users_from_state( + room_id, state_entry + ) + + cache.hosts_to_joined_users = {} + for user_id in joined_users: + host = intern_string(get_domain_from_id(user_id)) + cache.hosts_to_joined_users.setdefault(host, set()).add(user_id) + + if state_entry.state_group: + cache.state_group = state_entry.state_group + else: + cache.state_group = object() + + return frozenset(cache.hosts_to_joined_users) @cached(max_entries=10000) def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache": - return _JoinedHostsCache(self, room_id) + return _JoinedHostsCache() @cached(num_args=2) async def did_forget(self, user_id: str, room_id: str) -> bool: @@ -1062,71 +1132,18 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): await self.db_pool.runInteraction("forget_membership", f) +@attr.s(slots=True) class _JoinedHostsCache: - """Cache for joined hosts in a room that is optimised to handle updates - via state deltas. - """ - - def __init__(self, store, room_id): - self.store = store - self.room_id = room_id + """The cached data used by the `_get_joined_hosts_cache`.""" - self.hosts_to_joined_users = {} + # Dict of host to the set of their users in the room at the state group. + hosts_to_joined_users = attr.ib(type=Dict[str, Set[str]], factory=dict) - self.state_group = object() - - self.linearizer = Linearizer("_JoinedHostsCache") - - self._len = 0 - - async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]: - """Get set of destinations for a state entry - - Args: - state_entry - - Returns: - The destinations as a set. - """ - if state_entry.state_group == self.state_group: - return frozenset(self.hosts_to_joined_users) - - with (await self.linearizer.queue(())): - if state_entry.state_group == self.state_group: - pass - elif state_entry.prev_group == self.state_group: - for (typ, state_key), event_id in state_entry.delta_ids.items(): - if typ != EventTypes.Member: - continue - - host = intern_string(get_domain_from_id(state_key)) - user_id = state_key - known_joins = self.hosts_to_joined_users.setdefault(host, set()) - - event = await self.store.get_event(event_id) - if event.membership == Membership.JOIN: - known_joins.add(user_id) - else: - known_joins.discard(user_id) - - if not known_joins: - self.hosts_to_joined_users.pop(host, None) - else: - joined_users = await self.store.get_joined_users_from_state( - self.room_id, state_entry - ) - - self.hosts_to_joined_users = {} - for user_id in joined_users: - host = intern_string(get_domain_from_id(user_id)) - self.hosts_to_joined_users.setdefault(host, set()).add(user_id) - - if state_entry.state_group: - self.state_group = state_entry.state_group - else: - self.state_group = object() - self._len = sum(len(v) for v in self.hosts_to_joined_users.values()) - return frozenset(self.hosts_to_joined_users) + # The state group `hosts_to_joined_users` is derived from. Will be an object + # if the instance is newly created or if the state is not based on a state + # group. (An object is used as a sentinel value to ensure that it never is + # equal to anything else). + state_group = attr.ib(type=Union[object, int], factory=object) def __len__(self): - return self._len + return sum(len(v) for v in self.hosts_to_joined_users.values()) |