diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 582875c91a..fff259f74c 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from itertools import chain
from typing import (
TYPE_CHECKING,
AbstractSet,
@@ -57,15 +56,12 @@ from synapse.types import (
StrCollection,
get_domain_from_id,
)
-from synapse.util.async_helpers import Linearizer
-from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
- from synapse.state import _StateCacheEntry
logger = logging.getLogger(__name__)
@@ -91,10 +87,6 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
):
super().__init__(database, db_conn, hs)
- # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
- # at a time. Keyed by room_id.
- self._joined_host_linearizer = Linearizer("_JoinedHostsCache")
-
self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
if (
@@ -1057,120 +1049,6 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
"get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
)
- async def get_joined_hosts(
- self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
- ) -> FrozenSet[str]:
- state_group: Union[object, int] = state_entry.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- assert state_group is not None
- with Measure(self._clock, "get_joined_hosts"):
- return await self._get_joined_hosts(
- room_id, state_group, state, state_entry=state_entry
- )
-
- @cached(num_args=2, max_entries=10000, iterable=True)
- async def _get_joined_hosts(
- self,
- room_id: str,
- state_group: Union[object, int],
- state: StateMap[str],
- state_entry: "_StateCacheEntry",
- ) -> FrozenSet[str]:
- # We don't use `state_group`, it's there so that we can cache based on
- # it. However, its important that its never None, since two
- # current_state's with a state_group of None are likely to be different.
- #
- # The `state_group` must match the `state_entry.state_group` (if not None).
- assert state_group is not None
- assert state_entry.state_group is None or state_entry.state_group == state_group
-
- # We use a secondary cache of previous work to allow us to build up the
- # joined hosts for the given state group based on previous state groups.
- #
- # We cache one object per room containing the results of the last state
- # group we got joined hosts for. The idea is that generally
- # `get_joined_hosts` is called with the "current" state group for the
- # room, and so consecutive calls will be for consecutive state groups
- # which point to the previous state group.
- cache = await self._get_joined_hosts_cache(room_id)
-
- # If the state group in the cache matches, we already have the data we need.
- if state_entry.state_group == cache.state_group:
- return frozenset(cache.hosts_to_joined_users)
-
- # Since we'll mutate the cache we need to lock.
- async with self._joined_host_linearizer.queue(room_id):
- if state_entry.state_group == cache.state_group:
- # Same state group, so nothing to do. We've already checked for
- # this above, but the cache may have changed while waiting on
- # the lock.
- pass
- elif state_entry.prev_group == cache.state_group:
- # The cached work is for the previous state group, so we work out
- # the delta.
- assert state_entry.delta_ids is not None
- for (typ, state_key), event_id in state_entry.delta_ids.items():
- if typ != EventTypes.Member:
- continue
-
- host = intern_string(get_domain_from_id(state_key))
- user_id = state_key
- known_joins = cache.hosts_to_joined_users.setdefault(host, set())
-
- event = await self.get_event(event_id)
- if event.membership == Membership.JOIN:
- known_joins.add(user_id)
- else:
- known_joins.discard(user_id)
-
- if not known_joins:
- cache.hosts_to_joined_users.pop(host, None)
- else:
- # The cache doesn't match the state group or prev state group,
- # so we calculate the result from first principles.
- #
- # We need to fetch all hosts joined to the room according to `state` by
- # inspecting all join memberships in `state`. However, if the `state` is
- # relatively recent then many of its events are likely to be held in
- # the current state of the room, which is easily available and likely
- # cached.
- #
- # We therefore compute the set of `state` events not in the
- # current state and only fetch those.
- current_memberships = (
- await self._get_approximate_current_memberships_in_room(room_id)
- )
- unknown_state_events = {}
- joined_users_in_current_state = []
-
- for (type, state_key), event_id in state.items():
- if event_id not in current_memberships:
- unknown_state_events[type, state_key] = event_id
- elif current_memberships[event_id] == Membership.JOIN:
- joined_users_in_current_state.append(state_key)
-
- joined_user_ids = await self.get_joined_user_ids_from_state(
- room_id, unknown_state_events
- )
-
- cache.hosts_to_joined_users = {}
- for user_id in chain(joined_user_ids, joined_users_in_current_state):
- host = intern_string(get_domain_from_id(user_id))
- cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
-
- if state_entry.state_group:
- cache.state_group = state_entry.state_group
- else:
- cache.state_group = object()
-
- return frozenset(cache.hosts_to_joined_users)
-
async def _get_approximate_current_memberships_in_room(
self, room_id: str
) -> Mapping[str, Optional[str]]:
|