summary refs log tree commit diff
path: root/synapse/storage/controllers/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/controllers/state.py')
-rw-r--r--synapse/storage/controllers/state.py137
1 files changed, 135 insertions, 2 deletions
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 233df7cce2..278c7832ba 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from itertools import chain
 from typing import (
     TYPE_CHECKING,
     AbstractSet,
@@ -19,14 +20,16 @@ from typing import (
     Callable,
     Collection,
     Dict,
+    FrozenSet,
     Iterable,
     List,
     Mapping,
     Optional,
     Tuple,
+    Union,
 )
 
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.logging.opentracing import tag_args, trace
 from synapse.storage.roommember import ProfileInfo
@@ -34,14 +37,20 @@ from synapse.storage.util.partial_state_events_tracker import (
     PartialCurrentStateTracker,
     PartialStateEventsTracker,
 )
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateMap, get_domain_from_id
 from synapse.types.state import StateFilter
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches import intern_string
+from synapse.util.caches.descriptors import cached
 from synapse.util.cancellation import cancellable
+from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.state import _StateCacheEntry
     from synapse.storage.databases import Databases
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -52,10 +61,15 @@ class StateStorageController:
 
     def __init__(self, hs: "HomeServer", stores: "Databases"):
         self._is_mine_id = hs.is_mine_id
+        self._clock = hs.get_clock()
         self.stores = stores
         self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
         self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main)
 
+        # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
+        # at a time. Keyed by room_id.
+        self._joined_host_linearizer = Linearizer("_JoinedHostsCache")
+
     def notify_event_un_partial_stated(self, event_id: str) -> None:
         self._partial_state_events_tracker.notify_un_partial_stated(event_id)
 
@@ -627,3 +641,122 @@ class StateStorageController:
         await self._partial_state_room_tracker.await_full_state(room_id)
 
         return await self.stores.main.get_users_in_room_with_profiles(room_id)
+
+    async def get_joined_hosts(
+        self, room_id: str, state_entry: "_StateCacheEntry"
+    ) -> FrozenSet[str]:
+        state_group: Union[object, int] = state_entry.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        assert state_group is not None
+        with Measure(self._clock, "get_joined_hosts"):
+            return await self._get_joined_hosts(
+                room_id, state_group, state_entry=state_entry
+            )
+
+    @cached(num_args=2, max_entries=10000, iterable=True)
+    async def _get_joined_hosts(
+        self,
+        room_id: str,
+        state_group: Union[object, int],
+        state_entry: "_StateCacheEntry",
+    ) -> FrozenSet[str]:
+        # We don't use `state_group`, it's there so that we can cache based on
+        # it. However, its important that its never None, since two
+        # current_state's with a state_group of None are likely to be different.
+        #
+        # The `state_group` must match the `state_entry.state_group` (if not None).
+        assert state_group is not None
+        assert state_entry.state_group is None or state_entry.state_group == state_group
+
+        # We use a secondary cache of previous work to allow us to build up the
+        # joined hosts for the given state group based on previous state groups.
+        #
+        # We cache one object per room containing the results of the last state
+        # group we got joined hosts for. The idea is that generally
+        # `get_joined_hosts` is called with the "current" state group for the
+        # room, and so consecutive calls will be for consecutive state groups
+        # which point to the previous state group.
+        cache = await self.stores.main._get_joined_hosts_cache(room_id)
+
+        # If the state group in the cache matches, we already have the data we need.
+        if state_entry.state_group == cache.state_group:
+            return frozenset(cache.hosts_to_joined_users)
+
+        # Since we'll mutate the cache we need to lock.
+        async with self._joined_host_linearizer.queue(room_id):
+            if state_entry.state_group == cache.state_group:
+                # Same state group, so nothing to do. We've already checked for
+                # this above, but the cache may have changed while waiting on
+                # the lock.
+                pass
+            elif state_entry.prev_group == cache.state_group:
+                # The cached work is for the previous state group, so we work out
+                # the delta.
+                assert state_entry.delta_ids is not None
+                for (typ, state_key), event_id in state_entry.delta_ids.items():
+                    if typ != EventTypes.Member:
+                        continue
+
+                    host = intern_string(get_domain_from_id(state_key))
+                    user_id = state_key
+                    known_joins = cache.hosts_to_joined_users.setdefault(host, set())
+
+                    event = await self.stores.main.get_event(event_id)
+                    if event.membership == Membership.JOIN:
+                        known_joins.add(user_id)
+                    else:
+                        known_joins.discard(user_id)
+
+                        if not known_joins:
+                            cache.hosts_to_joined_users.pop(host, None)
+            else:
+                # The cache doesn't match the state group or prev state group,
+                # so we calculate the result from first principles.
+                #
+                # We need to fetch all hosts joined to the room according to `state` by
+                # inspecting all join memberships in `state`. However, if the `state` is
+                # relatively recent then many of its events are likely to be held in
+                # the current state of the room, which is easily available and likely
+                # cached.
+                #
+                # We therefore compute the set of `state` events not in the
+                # current state and only fetch those.
+                current_memberships = (
+                    await self.stores.main._get_approximate_current_memberships_in_room(
+                        room_id
+                    )
+                )
+                unknown_state_events = {}
+                joined_users_in_current_state = []
+
+                state = await state_entry.get_state(
+                    self, StateFilter.from_types([(EventTypes.Member, None)])
+                )
+
+                for (type, state_key), event_id in state.items():
+                    if event_id not in current_memberships:
+                        unknown_state_events[type, state_key] = event_id
+                    elif current_memberships[event_id] == Membership.JOIN:
+                        joined_users_in_current_state.append(state_key)
+
+                joined_user_ids = await self.stores.main.get_joined_user_ids_from_state(
+                    room_id, unknown_state_events
+                )
+
+                cache.hosts_to_joined_users = {}
+                for user_id in chain(joined_user_ids, joined_users_in_current_state):
+                    host = intern_string(get_domain_from_id(user_id))
+                    cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
+
+            if state_entry.state_group:
+                cache.state_group = state_entry.state_group
+            else:
+                cache.state_group = object()
+
+        return frozenset(cache.hosts_to_joined_users)