summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13605.misc1
-rw-r--r--synapse/handlers/device.py8
-rw-r--r--synapse/handlers/directory.py12
-rw-r--r--synapse/handlers/presence.py3
-rw-r--r--synapse/handlers/typing.py7
-rw-r--r--tests/federation/test_federation_sender.py17
6 files changed, 31 insertions, 17 deletions
diff --git a/changelog.d/13605.misc b/changelog.d/13605.misc
new file mode 100644
index 0000000000..88d518383b
--- /dev/null
+++ b/changelog.d/13605.misc
@@ -0,0 +1 @@
+Refactor `get_users_in_room(room_id)` mis-use with dedicated `get_current_hosts_in_room(room_id)` function.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index f5c586f657..9c2c3a0e68 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -310,6 +310,7 @@ class DeviceHandler(DeviceWorkerHandler):
         super().__init__(hs)
 
         self.federation_sender = hs.get_federation_sender()
+        self._storage_controllers = hs.get_storage_controllers()
 
         self.device_list_updater = DeviceListUpdater(hs, self)
 
@@ -694,8 +695,11 @@ class DeviceHandler(DeviceWorkerHandler):
 
                     # Ignore any users that aren't ours
                     if self.hs.is_mine_id(user_id):
-                        joined_user_ids = await self.store.get_users_in_room(room_id)
-                        hosts = {get_domain_from_id(u) for u in joined_user_ids}
+                        hosts = set(
+                            await self._storage_controllers.state.get_current_hosts_in_room(
+                                room_id
+                            )
+                        )
                         hosts.discard(self.server_name)
 
                     # Check if we've already sent this update to some hosts
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 948f66a94d..7127d5aefc 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -30,7 +30,7 @@ from synapse.api.errors import (
 from synapse.appservice import ApplicationService
 from synapse.module_api import NOT_SPAM
 from synapse.storage.databases.main.directory import RoomAliasMapping
-from synapse.types import JsonDict, Requester, RoomAlias, get_domain_from_id
+from synapse.types import JsonDict, Requester, RoomAlias
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -83,8 +83,9 @@ class DirectoryHandler:
         # TODO(erikj): Add transactions.
         # TODO(erikj): Check if there is a current association.
         if not servers:
-            users = await self.store.get_users_in_room(room_id)
-            servers = {get_domain_from_id(u) for u in users}
+            servers = await self._storage_controllers.state.get_current_hosts_in_room(
+                room_id
+            )
 
         if not servers:
             raise SynapseError(400, "Failed to get server list")
@@ -287,8 +288,9 @@ class DirectoryHandler:
                 Codes.NOT_FOUND,
             )
 
-        users = await self.store.get_users_in_room(room_id)
-        extra_servers = {get_domain_from_id(u) for u in users}
+        extra_servers = await self._storage_controllers.state.get_current_hosts_in_room(
+            room_id
+        )
         servers_set = set(extra_servers) | set(servers)
 
         # If this server is in the list of servers, return it first.
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 741504ba9f..4e575ffbaa 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -2051,8 +2051,7 @@ async def get_interested_remotes(
     )
 
     for room_id, states in room_ids_to_states.items():
-        user_ids = await store.get_users_in_room(room_id)
-        hosts = {get_domain_from_id(user_id) for user_id in user_ids}
+        hosts = await store.get_current_hosts_in_room(room_id)
         for host in hosts:
             hosts_and_states.setdefault(host, set()).update(states)
 
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index bcac3372a2..a4cd8b8f0c 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import (
 )
 from synapse.replication.tcp.streams import TypingStream
 from synapse.streams import EventSource
-from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -362,8 +362,9 @@ class TypingWriterHandler(FollowerTypingHandler):
             )
             return
 
-        users = await self.store.get_users_in_room(room_id)
-        domains = {get_domain_from_id(u) for u in users}
+        domains = await self._storage_controllers.state.get_current_hosts_in_room(
+            room_id
+        )
 
         if self.server_name in domains:
             logger.info("Got typing update from %s: %r", user_id, content)
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 01a1db6115..a5aa500ef8 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -173,17 +173,24 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         return c
 
     def prepare(self, reactor, clock, hs):
-        # stub out `get_rooms_for_user` and `get_users_in_room` so that the
+        test_room_id = "!room:host1"
+
+        # stub out `get_rooms_for_user` and `get_current_hosts_in_room` so that the
         # server thinks the user shares a room with `@user2:host2`
         def get_rooms_for_user(user_id):
-            return defer.succeed({"!room:host1"})
+            return defer.succeed({test_room_id})
 
         hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
 
-        def get_users_in_room(room_id):
-            return defer.succeed({"@user2:host2"})
+        async def get_current_hosts_in_room(room_id):
+            if room_id == test_room_id:
+                return ["host2"]
+
+            # TODO: We should fail the test when we encounter an unxpected room ID.
+            # We can't just use `self.fail(...)` here because the app code is greedy
+            # with `Exception` and will catch it before the test can see it.
 
-        hs.get_datastores().main.get_users_in_room = get_users_in_room
+        hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room
 
         # whenever send_transaction is called, record the edu data
         self.edus = []