summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py1
-rw-r--r--synapse/storage/controllers/state.py8
-rw-r--r--synapse/storage/databases/main/roommember.py37
3 files changed, 46 insertions, 0 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 57bd74700e..abfc56b061 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -71,6 +71,7 @@ class SQLBaseStore(metaclass=ABCMeta):
             self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
         if members_changed:
             self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
+            self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,))
             self._attempt_to_invalidate_cache(
                 "get_users_in_room_with_profiles", (room_id,)
             )
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 63a78ebc87..3b4cdb67eb 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -23,6 +23,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Set,
     Tuple,
 )
 
@@ -482,3 +483,10 @@ class StateStorageController:
             room_id, StateFilter.from_types((key,))
         )
         return state_map.get(key)
+
+    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+        """Get current hosts in room based on current state."""
+
+        await self._partial_state_room_tracker.await_full_state(room_id)
+
+        return await self.stores.main.get_current_hosts_in_room(room_id)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e222b7bd1f..31bc8c5601 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -893,6 +893,43 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return True
 
+    @cached(iterable=True, max_entries=10000)
+    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+        """Get current hosts in room based on current state."""
+
+        # First we check if we already have `get_users_in_room` in the cache, as
+        # we can just calculate result from that
+        users = self.get_users_in_room.cache.get_immediate(
+            (room_id,), None, update_metrics=False
+        )
+        if users is not None:
+            return {get_domain_from_id(u) for u in users}
+
+        if isinstance(self.database_engine, Sqlite3Engine):
+            # If we're using SQLite then let's just always use
+            # `get_users_in_room` rather than funky SQL.
+            users = await self.get_users_in_room(room_id)
+            return {get_domain_from_id(u) for u in users}
+
+        # For PostgreSQL we can use a regex to pull out the domains from the
+        # joined users in `current_state_events` via regex.
+
+        def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
+            sql = """
+                SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
+                FROM current_state_events
+                WHERE
+                    type = 'm.room.member'
+                    AND membership = 'join'
+                    AND room_id = ?
+            """
+            txn.execute(sql, (room_id,))
+            return {d for d, in txn}
+
+        return await self.db_pool.runInteraction(
+            "get_current_hosts_in_room", get_current_hosts_in_room_txn
+        )
+
     async def get_joined_hosts(
         self, room_id: str, state_entry: "_StateCacheEntry"
     ) -> FrozenSet[str]: