diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 233df7cce2..278c7832ba 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from itertools import chain
from typing import (
TYPE_CHECKING,
AbstractSet,
@@ -19,14 +20,16 @@ from typing import (
Callable,
Collection,
Dict,
+ FrozenSet,
Iterable,
List,
Mapping,
Optional,
Tuple,
+ Union,
)
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.logging.opentracing import tag_args, trace
from synapse.storage.roommember import ProfileInfo
@@ -34,14 +37,20 @@ from synapse.storage.util.partial_state_events_tracker import (
PartialCurrentStateTracker,
PartialStateEventsTracker,
)
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateMap, get_domain_from_id
from synapse.types.state import StateFilter
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches import intern_string
+from synapse.util.caches.descriptors import cached
from synapse.util.cancellation import cancellable
+from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
+ from synapse.state import _StateCacheEntry
from synapse.storage.databases import Databases
+
logger = logging.getLogger(__name__)
@@ -52,10 +61,15 @@ class StateStorageController:
def __init__(self, hs: "HomeServer", stores: "Databases"):
self._is_mine_id = hs.is_mine_id
+ self._clock = hs.get_clock()
self.stores = stores
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main)
+ # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
+ # at a time. Keyed by room_id.
+ self._joined_host_linearizer = Linearizer("_JoinedHostsCache")
+
def notify_event_un_partial_stated(self, event_id: str) -> None:
self._partial_state_events_tracker.notify_un_partial_stated(event_id)
@@ -627,3 +641,122 @@ class StateStorageController:
await self._partial_state_room_tracker.await_full_state(room_id)
return await self.stores.main.get_users_in_room_with_profiles(room_id)
+
+ async def get_joined_hosts(
+ self, room_id: str, state_entry: "_StateCacheEntry"
+ ) -> FrozenSet[str]:
+ state_group: Union[object, int] = state_entry.state_group
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ assert state_group is not None
+ with Measure(self._clock, "get_joined_hosts"):
+ return await self._get_joined_hosts(
+ room_id, state_group, state_entry=state_entry
+ )
+
+ @cached(num_args=2, max_entries=10000, iterable=True)
+ async def _get_joined_hosts(
+ self,
+ room_id: str,
+ state_group: Union[object, int],
+ state_entry: "_StateCacheEntry",
+ ) -> FrozenSet[str]:
+ # We don't use `state_group`, it's there so that we can cache based on
+ # it. However, its important that its never None, since two
+ # current_state's with a state_group of None are likely to be different.
+ #
+ # The `state_group` must match the `state_entry.state_group` (if not None).
+ assert state_group is not None
+ assert state_entry.state_group is None or state_entry.state_group == state_group
+
+ # We use a secondary cache of previous work to allow us to build up the
+ # joined hosts for the given state group based on previous state groups.
+ #
+ # We cache one object per room containing the results of the last state
+ # group we got joined hosts for. The idea is that generally
+ # `get_joined_hosts` is called with the "current" state group for the
+ # room, and so consecutive calls will be for consecutive state groups
+ # which point to the previous state group.
+ cache = await self.stores.main._get_joined_hosts_cache(room_id)
+
+ # If the state group in the cache matches, we already have the data we need.
+ if state_entry.state_group == cache.state_group:
+ return frozenset(cache.hosts_to_joined_users)
+
+ # Since we'll mutate the cache we need to lock.
+ async with self._joined_host_linearizer.queue(room_id):
+ if state_entry.state_group == cache.state_group:
+ # Same state group, so nothing to do. We've already checked for
+ # this above, but the cache may have changed while waiting on
+ # the lock.
+ pass
+ elif state_entry.prev_group == cache.state_group:
+ # The cached work is for the previous state group, so we work out
+ # the delta.
+ assert state_entry.delta_ids is not None
+ for (typ, state_key), event_id in state_entry.delta_ids.items():
+ if typ != EventTypes.Member:
+ continue
+
+ host = intern_string(get_domain_from_id(state_key))
+ user_id = state_key
+ known_joins = cache.hosts_to_joined_users.setdefault(host, set())
+
+ event = await self.stores.main.get_event(event_id)
+ if event.membership == Membership.JOIN:
+ known_joins.add(user_id)
+ else:
+ known_joins.discard(user_id)
+
+ if not known_joins:
+ cache.hosts_to_joined_users.pop(host, None)
+ else:
+ # The cache doesn't match the state group or prev state group,
+ # so we calculate the result from first principles.
+ #
+ # We need to fetch all hosts joined to the room according to `state` by
+ # inspecting all join memberships in `state`. However, if the `state` is
+ # relatively recent then many of its events are likely to be held in
+ # the current state of the room, which is easily available and likely
+ # cached.
+ #
+ # We therefore compute the set of `state` events not in the
+ # current state and only fetch those.
+ current_memberships = (
+ await self.stores.main._get_approximate_current_memberships_in_room(
+ room_id
+ )
+ )
+ unknown_state_events = {}
+ joined_users_in_current_state = []
+
+ state = await state_entry.get_state(
+ self, StateFilter.from_types([(EventTypes.Member, None)])
+ )
+
+ for (type, state_key), event_id in state.items():
+ if event_id not in current_memberships:
+ unknown_state_events[type, state_key] = event_id
+ elif current_memberships[event_id] == Membership.JOIN:
+ joined_users_in_current_state.append(state_key)
+
+ joined_user_ids = await self.stores.main.get_joined_user_ids_from_state(
+ room_id, unknown_state_events
+ )
+
+ cache.hosts_to_joined_users = {}
+ for user_id in chain(joined_user_ids, joined_users_in_current_state):
+ host = intern_string(get_domain_from_id(user_id))
+ cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
+
+ if state_entry.state_group:
+ cache.state_group = state_entry.state_group
+ else:
+ cache.state_group = object()
+
+ return frozenset(cache.hosts_to_joined_users)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 582875c91a..fff259f74c 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from itertools import chain
from typing import (
TYPE_CHECKING,
AbstractSet,
@@ -57,15 +56,12 @@ from synapse.types import (
StrCollection,
get_domain_from_id,
)
-from synapse.util.async_helpers import Linearizer
-from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
- from synapse.state import _StateCacheEntry
logger = logging.getLogger(__name__)
@@ -91,10 +87,6 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
):
super().__init__(database, db_conn, hs)
- # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
- # at a time. Keyed by room_id.
- self._joined_host_linearizer = Linearizer("_JoinedHostsCache")
-
self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
if (
@@ -1057,120 +1049,6 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
"get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
)
- async def get_joined_hosts(
- self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
- ) -> FrozenSet[str]:
- state_group: Union[object, int] = state_entry.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- assert state_group is not None
- with Measure(self._clock, "get_joined_hosts"):
- return await self._get_joined_hosts(
- room_id, state_group, state, state_entry=state_entry
- )
-
- @cached(num_args=2, max_entries=10000, iterable=True)
- async def _get_joined_hosts(
- self,
- room_id: str,
- state_group: Union[object, int],
- state: StateMap[str],
- state_entry: "_StateCacheEntry",
- ) -> FrozenSet[str]:
- # We don't use `state_group`, it's there so that we can cache based on
- # it. However, its important that its never None, since two
- # current_state's with a state_group of None are likely to be different.
- #
- # The `state_group` must match the `state_entry.state_group` (if not None).
- assert state_group is not None
- assert state_entry.state_group is None or state_entry.state_group == state_group
-
- # We use a secondary cache of previous work to allow us to build up the
- # joined hosts for the given state group based on previous state groups.
- #
- # We cache one object per room containing the results of the last state
- # group we got joined hosts for. The idea is that generally
- # `get_joined_hosts` is called with the "current" state group for the
- # room, and so consecutive calls will be for consecutive state groups
- # which point to the previous state group.
- cache = await self._get_joined_hosts_cache(room_id)
-
- # If the state group in the cache matches, we already have the data we need.
- if state_entry.state_group == cache.state_group:
- return frozenset(cache.hosts_to_joined_users)
-
- # Since we'll mutate the cache we need to lock.
- async with self._joined_host_linearizer.queue(room_id):
- if state_entry.state_group == cache.state_group:
- # Same state group, so nothing to do. We've already checked for
- # this above, but the cache may have changed while waiting on
- # the lock.
- pass
- elif state_entry.prev_group == cache.state_group:
- # The cached work is for the previous state group, so we work out
- # the delta.
- assert state_entry.delta_ids is not None
- for (typ, state_key), event_id in state_entry.delta_ids.items():
- if typ != EventTypes.Member:
- continue
-
- host = intern_string(get_domain_from_id(state_key))
- user_id = state_key
- known_joins = cache.hosts_to_joined_users.setdefault(host, set())
-
- event = await self.get_event(event_id)
- if event.membership == Membership.JOIN:
- known_joins.add(user_id)
- else:
- known_joins.discard(user_id)
-
- if not known_joins:
- cache.hosts_to_joined_users.pop(host, None)
- else:
- # The cache doesn't match the state group or prev state group,
- # so we calculate the result from first principles.
- #
- # We need to fetch all hosts joined to the room according to `state` by
- # inspecting all join memberships in `state`. However, if the `state` is
- # relatively recent then many of its events are likely to be held in
- # the current state of the room, which is easily available and likely
- # cached.
- #
- # We therefore compute the set of `state` events not in the
- # current state and only fetch those.
- current_memberships = (
- await self._get_approximate_current_memberships_in_room(room_id)
- )
- unknown_state_events = {}
- joined_users_in_current_state = []
-
- for (type, state_key), event_id in state.items():
- if event_id not in current_memberships:
- unknown_state_events[type, state_key] = event_id
- elif current_memberships[event_id] == Membership.JOIN:
- joined_users_in_current_state.append(state_key)
-
- joined_user_ids = await self.get_joined_user_ids_from_state(
- room_id, unknown_state_events
- )
-
- cache.hosts_to_joined_users = {}
- for user_id in chain(joined_user_ids, joined_users_in_current_state):
- host = intern_string(get_domain_from_id(user_id))
- cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
-
- if state_entry.state_group:
- cache.state_group = state_entry.state_group
- else:
- cache.state_group = object()
-
- return frozenset(cache.hosts_to_joined_users)
-
async def _get_approximate_current_memberships_in_room(
self, room_id: str
) -> Mapping[str, Optional[str]]:
|