diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index bd8513cd43..2a8532f8c1 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -23,8 +23,11 @@ from typing import (
Optional,
Set,
Tuple,
+ Union,
)
+import attr
+
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
@@ -43,7 +46,7 @@ from synapse.storage.roommember import (
ProfileInfo,
RoomsForUser,
)
-from synapse.types import PersistedEventPosition, get_domain_from_id
+from synapse.types import PersistedEventPosition, StateMap, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -63,6 +66,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
+ # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
+ # at a time. Keyed by room_id.
+ self._joined_host_linearizer = Linearizer("_JoinedHostsCache")
+
# Is the current_state_events.membership up to date? Or is the
# background update still running?
self._current_state_events_membership_up_to_date = False
@@ -740,19 +747,82 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(num_args=2, max_entries=10000, iterable=True)
async def _get_joined_hosts(
- self, room_id, state_group, current_state_ids, state_entry
- ):
- # We don't use `state_group`, its there so that we can cache based
- # on it. However, its important that its never None, since two current_state's
- # with a state_group of None are likely to be different.
+ self,
+ room_id: str,
+ state_group: int,
+ current_state_ids: StateMap[str],
+ state_entry: "_StateCacheEntry",
+ ) -> FrozenSet[str]:
+ # We don't use `state_group`, its there so that we can cache based on
+ # it. However, its important that its never None, since two
+ # current_state's with a state_group of None are likely to be different.
+ #
+ # The `state_group` must match the `state_entry.state_group` (if not None).
assert state_group is not None
-
+ assert state_entry.state_group is None or state_entry.state_group == state_group
+
+ # We use a secondary cache of previous work to allow us to build up the
+ # joined hosts for the given state group based on previous state groups.
+ #
+ # We cache one object per room containing the results of the last state
+ # group we got joined hosts for. The idea is that generally
+ # `get_joined_hosts` is called with the "current" state group for the
+ # room, and so consecutive calls will be for consecutive state groups
+ # which point to the previous state group.
cache = await self._get_joined_hosts_cache(room_id)
- return await cache.get_destinations(state_entry)
+
+ # If the state group in the cache matches, we already have the data we need.
+ if state_entry.state_group == cache.state_group:
+ return frozenset(cache.hosts_to_joined_users)
+
+ # Since we'll mutate the cache we need to lock.
+ with (await self._joined_host_linearizer.queue(room_id)):
+ if state_entry.state_group == cache.state_group:
+ # Same state group, so nothing to do. We've already checked for
+ # this above, but the cache may have changed while waiting on
+ # the lock.
+ pass
+ elif state_entry.prev_group == cache.state_group:
+ # The cached work is for the previous state group, so we work out
+ # the delta.
+ for (typ, state_key), event_id in state_entry.delta_ids.items():
+ if typ != EventTypes.Member:
+ continue
+
+ host = intern_string(get_domain_from_id(state_key))
+ user_id = state_key
+ known_joins = cache.hosts_to_joined_users.setdefault(host, set())
+
+ event = await self.get_event(event_id)
+ if event.membership == Membership.JOIN:
+ known_joins.add(user_id)
+ else:
+ known_joins.discard(user_id)
+
+ if not known_joins:
+ cache.hosts_to_joined_users.pop(host, None)
+ else:
+ # The cache doesn't match the state group or prev state group,
+ # so we calculate the result from first principles.
+ joined_users = await self.get_joined_users_from_state(
+ room_id, state_entry
+ )
+
+ cache.hosts_to_joined_users = {}
+ for user_id in joined_users:
+ host = intern_string(get_domain_from_id(user_id))
+ cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
+
+ if state_entry.state_group:
+ cache.state_group = state_entry.state_group
+ else:
+ cache.state_group = object()
+
+ return frozenset(cache.hosts_to_joined_users)
@cached(max_entries=10000)
def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
- return _JoinedHostsCache(self, room_id)
+ return _JoinedHostsCache()
@cached(num_args=2)
async def did_forget(self, user_id: str, room_id: str) -> bool:
@@ -1062,71 +1132,18 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
await self.db_pool.runInteraction("forget_membership", f)
+@attr.s(slots=True)
class _JoinedHostsCache:
- """Cache for joined hosts in a room that is optimised to handle updates
- via state deltas.
- """
-
- def __init__(self, store, room_id):
- self.store = store
- self.room_id = room_id
+ """The cached data used by the `_get_joined_hosts_cache`."""
- self.hosts_to_joined_users = {}
+ # Dict of host to the set of their users in the room at the state group.
+ hosts_to_joined_users = attr.ib(type=Dict[str, Set[str]], factory=dict)
- self.state_group = object()
-
- self.linearizer = Linearizer("_JoinedHostsCache")
-
- self._len = 0
-
- async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]:
- """Get set of destinations for a state entry
-
- Args:
- state_entry
-
- Returns:
- The destinations as a set.
- """
- if state_entry.state_group == self.state_group:
- return frozenset(self.hosts_to_joined_users)
-
- with (await self.linearizer.queue(())):
- if state_entry.state_group == self.state_group:
- pass
- elif state_entry.prev_group == self.state_group:
- for (typ, state_key), event_id in state_entry.delta_ids.items():
- if typ != EventTypes.Member:
- continue
-
- host = intern_string(get_domain_from_id(state_key))
- user_id = state_key
- known_joins = self.hosts_to_joined_users.setdefault(host, set())
-
- event = await self.store.get_event(event_id)
- if event.membership == Membership.JOIN:
- known_joins.add(user_id)
- else:
- known_joins.discard(user_id)
-
- if not known_joins:
- self.hosts_to_joined_users.pop(host, None)
- else:
- joined_users = await self.store.get_joined_users_from_state(
- self.room_id, state_entry
- )
-
- self.hosts_to_joined_users = {}
- for user_id in joined_users:
- host = intern_string(get_domain_from_id(user_id))
- self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
-
- if state_entry.state_group:
- self.state_group = state_entry.state_group
- else:
- self.state_group = object()
- self._len = sum(len(v) for v in self.hosts_to_joined_users.values())
- return frozenset(self.hosts_to_joined_users)
+ # The state group `hosts_to_joined_users` is derived from. Will be an object
+ # if the instance is newly created or if the state is not based on a state
+ # group. (An object is used as a sentinel value to ensure that it never is
+ # equal to anything else).
+ state_group = attr.ib(type=Union[object, int], factory=object)
def __len__(self):
- return self._len
+ return sum(len(v) for v in self.hosts_to_joined_users.values())
|