summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12964.misc1
-rw-r--r--synapse/federation/sender/__init__.py6
-rw-r--r--synapse/handlers/typing.py7
-rw-r--r--synapse/state/__init__.py4
-rw-r--r--synapse/storage/_base.py1
-rw-r--r--synapse/storage/controllers/state.py8
-rw-r--r--synapse/storage/databases/main/roommember.py37
-rw-r--r--tests/federation/test_federation_sender.py14
-rw-r--r--tests/handlers/test_typing.py6
9 files changed, 68 insertions, 16 deletions
diff --git a/changelog.d/12964.misc b/changelog.d/12964.misc
new file mode 100644
index 0000000000..d57e1aca6b
--- /dev/null
+++ b/changelog.d/12964.misc
@@ -0,0 +1 @@
+Reduce the amount of state we pull from the DB.
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index dbe303ed9b..99a794c042 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -245,6 +245,8 @@ class FederationSender(AbstractFederationSender):
         self.store = hs.get_datastores().main
         self.state = hs.get_state_handler()
 
+        self._storage_controllers = hs.get_storage_controllers()
+
         self.clock = hs.get_clock()
         self.is_mine_id = hs.is_mine_id
 
@@ -602,7 +604,9 @@ class FederationSender(AbstractFederationSender):
         room_id = receipt.room_id
 
         # Work out which remote servers should be poked and poke them.
-        domains_set = await self.state.get_current_hosts_in_room(room_id)
+        domains_set = await self._storage_controllers.state.get_current_hosts_in_room(
+            room_id
+        )
         domains = [
             d
             for d in domains_set
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 0aeab86bbb..d104ea07fe 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -59,6 +59,7 @@ class FollowerTypingHandler:
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
         self.server_name = hs.config.server.server_name
         self.clock = hs.get_clock()
         self.is_mine_id = hs.is_mine_id
@@ -131,7 +132,6 @@ class FollowerTypingHandler:
             return
 
         try:
-            users = await self.store.get_users_in_room(member.room_id)
             self._member_last_federation_poke[member] = self.clock.time_msec()
 
             now = self.clock.time_msec()
@@ -139,7 +139,10 @@ class FollowerTypingHandler:
                 now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
             )
 
-            for domain in {get_domain_from_id(u) for u in users}:
+            hosts = await self._storage_controllers.state.get_current_hosts_in_room(
+                member.room_id
+            )
+            for domain in hosts:
                 if domain != self.server_name:
                     logger.debug("sending typing update to %s", domain)
                     self.federation.build_and_send_edu(
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index ab68e2b6a4..da25f20ae5 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -172,10 +172,6 @@ class StateHandler:
         entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
         return await self.store.get_joined_users_from_state(room_id, entry)
 
-    async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
-        event_ids = await self.store.get_latest_event_ids_in_room(room_id)
-        return await self.get_hosts_in_room_at_events(room_id, event_ids)
-
     async def get_hosts_in_room_at_events(
         self, room_id: str, event_ids: Collection[str]
     ) -> FrozenSet[str]:
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 57bd74700e..abfc56b061 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -71,6 +71,7 @@ class SQLBaseStore(metaclass=ABCMeta):
             self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
         if members_changed:
             self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
+            self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,))
             self._attempt_to_invalidate_cache(
                 "get_users_in_room_with_profiles", (room_id,)
             )
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 63a78ebc87..3b4cdb67eb 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -23,6 +23,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Set,
     Tuple,
 )
 
@@ -482,3 +483,10 @@ class StateStorageController:
             room_id, StateFilter.from_types((key,))
         )
         return state_map.get(key)
+
+    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+        """Get current hosts in room based on current state."""
+
+        await self._partial_state_room_tracker.await_full_state(room_id)
+
+        return await self.stores.main.get_current_hosts_in_room(room_id)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e222b7bd1f..31bc8c5601 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -893,6 +893,43 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return True
 
+    @cached(iterable=True, max_entries=10000)
+    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+        """Get current hosts in room based on current state."""
+
+        # First we check if we already have `get_users_in_room` in the cache, as
+        # we can just calculate result from that
+        users = self.get_users_in_room.cache.get_immediate(
+            (room_id,), None, update_metrics=False
+        )
+        if users is not None:
+            return {get_domain_from_id(u) for u in users}
+
+        if isinstance(self.database_engine, Sqlite3Engine):
+            # If we're using SQLite then let's just always use
+            # `get_users_in_room` rather than funky SQL.
+            users = await self.get_users_in_room(room_id)
+            return {get_domain_from_id(u) for u in users}
+
+        # For PostgreSQL we can use a regex to pull out the domains from the
+        # joined users in `current_state_events` via regex.
+
+        def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
+            sql = """
+                SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
+                FROM current_state_events
+                WHERE
+                    type = 'm.room.member'
+                    AND membership = 'join'
+                    AND room_id = ?
+            """
+            txn.execute(sql, (room_id,))
+            return {d for d, in txn}
+
+        return await self.db_pool.runInteraction(
+            "get_current_hosts_in_room", get_current_hosts_in_room_txn
+        )
+
     async def get_joined_hosts(
         self, room_id: str, state_entry: "_StateCacheEntry"
     ) -> FrozenSet[str]:
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index b5be727fe4..01a1db6115 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -30,16 +30,16 @@ from tests.unittest import HomeserverTestCase, override_config
 
 class FederationSenderReceiptsTestCases(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
-        mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
-        # Ensure a new Awaitable is created for each call.
-        mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
-            ["test", "host2"]
-        )
-        return self.setup_test_homeserver(
-            state_handler=mock_state_handler,
+        hs = self.setup_test_homeserver(
             federation_transport_client=Mock(spec=["send_transaction"]),
         )
 
+        hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(
+            return_value=make_awaitable({"test", "host2"})
+        )
+
+        return hs
+
     @override_config({"send_federation": True})
     def test_send_receipts(self):
         mock_send_transaction = (
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 14a0ee4922..7af1333126 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -129,10 +129,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         hs.get_event_auth_handler().check_host_in_room = check_host_in_room
 
-        def get_joined_hosts_for_room(room_id: str):
+        async def get_current_hosts_in_room(room_id: str):
             return {member.domain for member in self.room_members}
 
-        self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
+        hs.get_storage_controllers().state.get_current_hosts_in_room = (
+            get_current_hosts_in_room
+        )
 
         async def get_users_in_room(room_id: str):
             return {str(u) for u in self.room_members}