diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index dbe303ed9b..99a794c042 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -245,6 +245,8 @@ class FederationSender(AbstractFederationSender):
self.store = hs.get_datastores().main
self.state = hs.get_state_handler()
+ self._storage_controllers = hs.get_storage_controllers()
+
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
@@ -602,7 +604,9 @@ class FederationSender(AbstractFederationSender):
room_id = receipt.room_id
# Work out which remote servers should be poked and poke them.
- domains_set = await self.state.get_current_hosts_in_room(room_id)
+ domains_set = await self._storage_controllers.state.get_current_hosts_in_room(
+ room_id
+ )
domains = [
d
for d in domains_set
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 0aeab86bbb..d104ea07fe 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -59,6 +59,7 @@ class FollowerTypingHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.server_name = hs.config.server.server_name
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
@@ -131,7 +132,6 @@ class FollowerTypingHandler:
return
try:
- users = await self.store.get_users_in_room(member.room_id)
self._member_last_federation_poke[member] = self.clock.time_msec()
now = self.clock.time_msec()
@@ -139,7 +139,10 @@ class FollowerTypingHandler:
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
)
- for domain in {get_domain_from_id(u) for u in users}:
+ hosts = await self._storage_controllers.state.get_current_hosts_in_room(
+ member.room_id
+ )
+ for domain in hosts:
if domain != self.server_name:
logger.debug("sending typing update to %s", domain)
self.federation.build_and_send_edu(
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index ab68e2b6a4..da25f20ae5 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -172,10 +172,6 @@ class StateHandler:
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return await self.store.get_joined_users_from_state(room_id, entry)
- async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
- event_ids = await self.store.get_latest_event_ids_in_room(room_id)
- return await self.get_hosts_in_room_at_events(room_id, event_ids)
-
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
) -> FrozenSet[str]:
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 57bd74700e..abfc56b061 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -71,6 +71,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
if members_changed:
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
+ self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,))
self._attempt_to_invalidate_cache(
"get_users_in_room_with_profiles", (room_id,)
)
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 63a78ebc87..3b4cdb67eb 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -23,6 +23,7 @@ from typing import (
List,
Mapping,
Optional,
+ Set,
Tuple,
)
@@ -482,3 +483,10 @@ class StateStorageController:
room_id, StateFilter.from_types((key,))
)
return state_map.get(key)
+
+ async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+ """Get current hosts in room based on current state."""
+
+ await self._partial_state_room_tracker.await_full_state(room_id)
+
+ return await self.stores.main.get_current_hosts_in_room(room_id)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e222b7bd1f..31bc8c5601 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -893,6 +893,43 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
+ @cached(iterable=True, max_entries=10000)
+ async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+ """Get current hosts in room based on current state."""
+
+ # First we check if we already have `get_users_in_room` in the cache, as
+ # we can just calculate result from that
+ users = self.get_users_in_room.cache.get_immediate(
+ (room_id,), None, update_metrics=False
+ )
+ if users is not None:
+ return {get_domain_from_id(u) for u in users}
+
+ if isinstance(self.database_engine, Sqlite3Engine):
+ # If we're using SQLite then let's just always use
+ # `get_users_in_room` rather than funky SQL.
+ users = await self.get_users_in_room(room_id)
+ return {get_domain_from_id(u) for u in users}
+
+ # For PostgreSQL we can use a regex to pull out the domains from the
+ # joined users in `current_state_events` via regex.
+
+ def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
+ sql = """
+ SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
+ FROM current_state_events
+ WHERE
+ type = 'm.room.member'
+ AND membership = 'join'
+ AND room_id = ?
+ """
+ txn.execute(sql, (room_id,))
+ return {d for d, in txn}
+
+ return await self.db_pool.runInteraction(
+ "get_current_hosts_in_room", get_current_hosts_in_room_txn
+ )
+
async def get_joined_hosts(
self, room_id: str, state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
|