summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/message.py9
-rw-r--r--synapse/state/__init__.py3
-rw-r--r--synapse/storage/controllers/state.py137
-rw-r--r--synapse/storage/databases/main/roommember.py122
4 files changed, 140 insertions, 131 deletions
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9910716bc6..fff0b5fa12 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1565,12 +1565,11 @@ class EventCreationHandler:
                 if state_entry.state_group in self._external_cache_joined_hosts_updates:
                     return
 
-                state = await state_entry.get_state(
-                    self._storage_controllers.state, StateFilter.all()
-                )
                 with opentracing.start_active_span("get_joined_hosts"):
-                    joined_hosts = await self.store.get_joined_hosts(
-                        event.room_id, state, state_entry
+                    joined_hosts = (
+                        await self._storage_controllers.state.get_joined_hosts(
+                            event.room_id, state_entry
+                        )
                     )
 
                 # Note that the expiry times must be larger than the expiry time in
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 9bc0c3b7b9..1b91cf5eaa 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -268,8 +268,7 @@ class StateHandler:
             The hosts in the room at the given events
         """
         entry = await self.resolve_state_groups_for_events(room_id, event_ids)
-        state = await entry.get_state(self._state_storage_controller, StateFilter.all())
-        return await self.store.get_joined_hosts(room_id, state, entry)
+        return await self._state_storage_controller.get_joined_hosts(room_id, entry)
 
     @trace
     @tag_args
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 233df7cce2..278c7832ba 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from itertools import chain
 from typing import (
     TYPE_CHECKING,
     AbstractSet,
@@ -19,14 +20,16 @@ from typing import (
     Callable,
     Collection,
     Dict,
+    FrozenSet,
     Iterable,
     List,
     Mapping,
     Optional,
     Tuple,
+    Union,
 )
 
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.logging.opentracing import tag_args, trace
 from synapse.storage.roommember import ProfileInfo
@@ -34,14 +37,20 @@ from synapse.storage.util.partial_state_events_tracker import (
     PartialCurrentStateTracker,
     PartialStateEventsTracker,
 )
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateMap, get_domain_from_id
 from synapse.types.state import StateFilter
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches import intern_string
+from synapse.util.caches.descriptors import cached
 from synapse.util.cancellation import cancellable
+from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.state import _StateCacheEntry
     from synapse.storage.databases import Databases
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -52,10 +61,15 @@ class StateStorageController:
 
     def __init__(self, hs: "HomeServer", stores: "Databases"):
         self._is_mine_id = hs.is_mine_id
+        self._clock = hs.get_clock()
         self.stores = stores
         self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
         self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main)
 
+        # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
+        # at a time. Keyed by room_id.
+        self._joined_host_linearizer = Linearizer("_JoinedHostsCache")
+
     def notify_event_un_partial_stated(self, event_id: str) -> None:
         self._partial_state_events_tracker.notify_un_partial_stated(event_id)
 
@@ -627,3 +641,122 @@ class StateStorageController:
         await self._partial_state_room_tracker.await_full_state(room_id)
 
         return await self.stores.main.get_users_in_room_with_profiles(room_id)
+
+    async def get_joined_hosts(
+        self, room_id: str, state_entry: "_StateCacheEntry"
+    ) -> FrozenSet[str]:
+        state_group: Union[object, int] = state_entry.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        assert state_group is not None
+        with Measure(self._clock, "get_joined_hosts"):
+            return await self._get_joined_hosts(
+                room_id, state_group, state_entry=state_entry
+            )
+
+    @cached(num_args=2, max_entries=10000, iterable=True)
+    async def _get_joined_hosts(
+        self,
+        room_id: str,
+        state_group: Union[object, int],
+        state_entry: "_StateCacheEntry",
+    ) -> FrozenSet[str]:
+        # We don't use `state_group`, it's there so that we can cache based on
+        # it. However, its important that its never None, since two
+        # current_state's with a state_group of None are likely to be different.
+        #
+        # The `state_group` must match the `state_entry.state_group` (if not None).
+        assert state_group is not None
+        assert state_entry.state_group is None or state_entry.state_group == state_group
+
+        # We use a secondary cache of previous work to allow us to build up the
+        # joined hosts for the given state group based on previous state groups.
+        #
+        # We cache one object per room containing the results of the last state
+        # group we got joined hosts for. The idea is that generally
+        # `get_joined_hosts` is called with the "current" state group for the
+        # room, and so consecutive calls will be for consecutive state groups
+        # which point to the previous state group.
+        cache = await self.stores.main._get_joined_hosts_cache(room_id)
+
+        # If the state group in the cache matches, we already have the data we need.
+        if state_entry.state_group == cache.state_group:
+            return frozenset(cache.hosts_to_joined_users)
+
+        # Since we'll mutate the cache we need to lock.
+        async with self._joined_host_linearizer.queue(room_id):
+            if state_entry.state_group == cache.state_group:
+                # Same state group, so nothing to do. We've already checked for
+                # this above, but the cache may have changed while waiting on
+                # the lock.
+                pass
+            elif state_entry.prev_group == cache.state_group:
+                # The cached work is for the previous state group, so we work out
+                # the delta.
+                assert state_entry.delta_ids is not None
+                for (typ, state_key), event_id in state_entry.delta_ids.items():
+                    if typ != EventTypes.Member:
+                        continue
+
+                    host = intern_string(get_domain_from_id(state_key))
+                    user_id = state_key
+                    known_joins = cache.hosts_to_joined_users.setdefault(host, set())
+
+                    event = await self.stores.main.get_event(event_id)
+                    if event.membership == Membership.JOIN:
+                        known_joins.add(user_id)
+                    else:
+                        known_joins.discard(user_id)
+
+                        if not known_joins:
+                            cache.hosts_to_joined_users.pop(host, None)
+            else:
+                # The cache doesn't match the state group or prev state group,
+                # so we calculate the result from first principles.
+                #
+                # We need to fetch all hosts joined to the room according to `state` by
+                # inspecting all join memberships in `state`. However, if the `state` is
+                # relatively recent then many of its events are likely to be held in
+                # the current state of the room, which is easily available and likely
+                # cached.
+                #
+                # We therefore compute the set of `state` events not in the
+                # current state and only fetch those.
+                current_memberships = (
+                    await self.stores.main._get_approximate_current_memberships_in_room(
+                        room_id
+                    )
+                )
+                unknown_state_events = {}
+                joined_users_in_current_state = []
+
+                state = await state_entry.get_state(
+                    self, StateFilter.from_types([(EventTypes.Member, None)])
+                )
+
+                for (type, state_key), event_id in state.items():
+                    if event_id not in current_memberships:
+                        unknown_state_events[type, state_key] = event_id
+                    elif current_memberships[event_id] == Membership.JOIN:
+                        joined_users_in_current_state.append(state_key)
+
+                joined_user_ids = await self.stores.main.get_joined_user_ids_from_state(
+                    room_id, unknown_state_events
+                )
+
+                cache.hosts_to_joined_users = {}
+                for user_id in chain(joined_user_ids, joined_users_in_current_state):
+                    host = intern_string(get_domain_from_id(user_id))
+                    cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
+
+            if state_entry.state_group:
+                cache.state_group = state_entry.state_group
+            else:
+                cache.state_group = object()
+
+        return frozenset(cache.hosts_to_joined_users)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 582875c91a..fff259f74c 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from itertools import chain
 from typing import (
     TYPE_CHECKING,
     AbstractSet,
@@ -57,15 +56,12 @@ from synapse.types import (
     StrCollection,
     get_domain_from_id,
 )
-from synapse.util.async_helpers import Linearizer
-from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
-    from synapse.state import _StateCacheEntry
 
 logger = logging.getLogger(__name__)
 
@@ -91,10 +87,6 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
-        # at a time. Keyed by room_id.
-        self._joined_host_linearizer = Linearizer("_JoinedHostsCache")
-
         self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
 
         if (
@@ -1057,120 +1049,6 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
             "get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
         )
 
-    async def get_joined_hosts(
-        self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
-    ) -> FrozenSet[str]:
-        state_group: Union[object, int] = state_entry.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
-
-        assert state_group is not None
-        with Measure(self._clock, "get_joined_hosts"):
-            return await self._get_joined_hosts(
-                room_id, state_group, state, state_entry=state_entry
-            )
-
-    @cached(num_args=2, max_entries=10000, iterable=True)
-    async def _get_joined_hosts(
-        self,
-        room_id: str,
-        state_group: Union[object, int],
-        state: StateMap[str],
-        state_entry: "_StateCacheEntry",
-    ) -> FrozenSet[str]:
-        # We don't use `state_group`, it's there so that we can cache based on
-        # it. However, its important that its never None, since two
-        # current_state's with a state_group of None are likely to be different.
-        #
-        # The `state_group` must match the `state_entry.state_group` (if not None).
-        assert state_group is not None
-        assert state_entry.state_group is None or state_entry.state_group == state_group
-
-        # We use a secondary cache of previous work to allow us to build up the
-        # joined hosts for the given state group based on previous state groups.
-        #
-        # We cache one object per room containing the results of the last state
-        # group we got joined hosts for. The idea is that generally
-        # `get_joined_hosts` is called with the "current" state group for the
-        # room, and so consecutive calls will be for consecutive state groups
-        # which point to the previous state group.
-        cache = await self._get_joined_hosts_cache(room_id)
-
-        # If the state group in the cache matches, we already have the data we need.
-        if state_entry.state_group == cache.state_group:
-            return frozenset(cache.hosts_to_joined_users)
-
-        # Since we'll mutate the cache we need to lock.
-        async with self._joined_host_linearizer.queue(room_id):
-            if state_entry.state_group == cache.state_group:
-                # Same state group, so nothing to do. We've already checked for
-                # this above, but the cache may have changed while waiting on
-                # the lock.
-                pass
-            elif state_entry.prev_group == cache.state_group:
-                # The cached work is for the previous state group, so we work out
-                # the delta.
-                assert state_entry.delta_ids is not None
-                for (typ, state_key), event_id in state_entry.delta_ids.items():
-                    if typ != EventTypes.Member:
-                        continue
-
-                    host = intern_string(get_domain_from_id(state_key))
-                    user_id = state_key
-                    known_joins = cache.hosts_to_joined_users.setdefault(host, set())
-
-                    event = await self.get_event(event_id)
-                    if event.membership == Membership.JOIN:
-                        known_joins.add(user_id)
-                    else:
-                        known_joins.discard(user_id)
-
-                        if not known_joins:
-                            cache.hosts_to_joined_users.pop(host, None)
-            else:
-                # The cache doesn't match the state group or prev state group,
-                # so we calculate the result from first principles.
-                #
-                # We need to fetch all hosts joined to the room according to `state` by
-                # inspecting all join memberships in `state`. However, if the `state` is
-                # relatively recent then many of its events are likely to be held in
-                # the current state of the room, which is easily available and likely
-                # cached.
-                #
-                # We therefore compute the set of `state` events not in the
-                # current state and only fetch those.
-                current_memberships = (
-                    await self._get_approximate_current_memberships_in_room(room_id)
-                )
-                unknown_state_events = {}
-                joined_users_in_current_state = []
-
-                for (type, state_key), event_id in state.items():
-                    if event_id not in current_memberships:
-                        unknown_state_events[type, state_key] = event_id
-                    elif current_memberships[event_id] == Membership.JOIN:
-                        joined_users_in_current_state.append(state_key)
-
-                joined_user_ids = await self.get_joined_user_ids_from_state(
-                    room_id, unknown_state_events
-                )
-
-                cache.hosts_to_joined_users = {}
-                for user_id in chain(joined_user_ids, joined_users_in_current_state):
-                    host = intern_string(get_domain_from_id(user_id))
-                    cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
-
-            if state_entry.state_group:
-                cache.state_group = state_entry.state_group
-            else:
-                cache.state_group = object()
-
-        return frozenset(cache.hosts_to_joined_users)
-
     async def _get_approximate_current_memberships_in_room(
         self, room_id: str
     ) -> Mapping[str, Optional[str]]: