summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-05-20 12:33:21 +0100
committerErik Johnston <erik@matrix.org>2022-05-20 12:51:53 +0100
commit11efe7231f4c5320493460ec41f4e9773c6ca120 (patch)
treeaf97ead1f5193cb20b727db17f0e7241f91d3afd
parentAdd helper methods to store (diff)
downloadsynapse-11efe7231f4c5320493460ec41f4e9773c6ca120.tar.xz
Use new helper functions
-rw-r--r--synapse/handlers/federation.py23
-rw-r--r--synapse/handlers/federation_event.py6
-rw-r--r--synapse/handlers/initial_sync.py4
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--synapse/handlers/room_member.py14
-rw-r--r--synapse/handlers/search.py2
-rw-r--r--synapse/rest/admin/rooms.py23
-rw-r--r--tests/federation/test_federation_server.py4
-rw-r--r--tests/storage/test_events.py14
9 files changed, 60 insertions, 32 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0386d0a07b..3843c304f7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -20,7 +20,16 @@ import itertools
 import logging
 from enum import Enum
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import attr
 from signedjson.key import decode_verify_key_bytes
@@ -353,15 +362,11 @@ class FederationHandler:
         # First we try hosts that are already in the room
         # TODO: HEURISTIC ALERT.
 
-        curr_state = await self.state_handler.get_current_state(room_id)
-
-        curr_domains = get_domains_from_state(curr_state)
-
-        likely_domains = [
-            domain for domain, depth in curr_domains if domain != self.server_name
-        ]
+        users_in_room = await self.store.get_users_in_room(room_id)
+        likely_domains = {get_domain_from_id(u) for u in users_in_room}
+        likely_domains.discard(self.server_name)
 
-        async def try_backfill(domains: List[str]) -> bool:
+        async def try_backfill(domains: Collection[str]) -> bool:
             # TODO: Should we try multiple of these at a time?
             for dom in domains:
                 try:
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 05c122f224..383242a4c9 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -1558,9 +1558,9 @@ class FederationEventHandler:
         if guest_access == GuestAccess.CAN_JOIN:
             return
 
-        current_state_map = await self._state_handler.get_current_state(event.room_id)
-        current_state = list(current_state_map.values())
-        await self._get_room_member_handler().kick_guest_users(current_state)
+        current_state = await self._store.get_current_state(event.room_id)
+        current_state_list = list(current_state.values())
+        await self._get_room_member_handler().kick_guest_users(current_state_list)
 
     async def _check_for_soft_fail(
         self,
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index d79248ad90..7e6fc97e38 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -190,7 +190,7 @@ class InitialSyncHandler:
                 if event.membership == Membership.JOIN:
                     room_end_token = now_token.room_key
                     deferred_room_state = run_in_background(
-                        self.state_handler.get_current_state, event.room_id
+                        self.store.get_current_state, event.room_id
                     )
                 elif event.membership == Membership.LEAVE:
                     room_end_token = RoomStreamToken(
@@ -404,7 +404,7 @@ class InitialSyncHandler:
         membership: str,
         is_peeking: bool,
     ) -> JsonDict:
-        current_state = await self.state.get_current_state(room_id=room_id)
+        current_state = await self.store.get_current_state(room_id=room_id)
 
         # TODO: These concurrently
         time_now = self.clock.time_msec()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 92e1de0500..dca240bba4 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1399,7 +1399,7 @@ class TimestampLookupHandler:
             )
 
             # Find other homeservers from the given state in the room
-            curr_state = await self.state_handler.get_current_state(room_id)
+            curr_state = await self.store.get_current_state(room_id)
             curr_domains = get_domains_from_state(curr_state)
             likely_domains = [
                 domain for domain, depth in curr_domains if domain != self.server_name
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 8d6b255cf6..85fc1aedf3 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1409,7 +1409,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         txn_id: Optional[str],
         id_access_token: Optional[str] = None,
     ) -> int:
-        room_state = await self.state_handler.get_current_state(room_id)
+        room_state = await self.store.get_filtered_current_state(
+            room_id,
+            StateFilter.from_types(
+                [
+                    (EventTypes.Member, user.to_string()),
+                    (EventTypes.CanonicalAlias, ""),
+                    (EventTypes.Name, ""),
+                    (EventTypes.Create, ""),
+                    (EventTypes.JoinRules, ""),
+                    (EventTypes.RoomAvatar, ""),
+                ]
+            ),
+        )
 
         inviter_display_name = ""
         inviter_avatar_url = ""
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index cd1c47dae8..604b8b4944 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -348,7 +348,7 @@ class SearchHandler:
         state_results = {}
         if include_state:
             for room_id in {e.room_id for e in search_result.allowed_events}:
-                state = await self.state_handler.get_current_state(room_id)
+                state = await self.store.get_current_state(room_id)
                 state_results[room_id] = list(state.values())
 
         aggregations = await self._relations_handler.get_bundled_aggregations(
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index b05bda04cb..8c6b3d7fe7 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -34,6 +34,7 @@ from synapse.rest.admin._base import (
     assert_user_is_admin,
 )
 from synapse.storage.databases.main.room import RoomSortOrder
+from synapse.storage.state import StateFilter
 from synapse.types import JsonDict, RoomID, UserID, create_requester
 from synapse.util import json_decoder
 
@@ -553,12 +554,22 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
         user_to_add = content.get("user_id", requester.user.to_string())
 
         # Figure out which local users currently have power in the room, if any.
-        room_state = await self.state_handler.get_current_state(room_id)
-        if not room_state:
+        filtered_room_state = await self.store.get_filtered_current_state(
+            room_id,
+            StateFilter.from_types(
+                [
+                    (EventTypes.Create, ""),
+                    (EventTypes.PowerLevels, ""),
+                    (EventTypes.JoinRules, ""),
+                    (EventTypes.Member, user_to_add),
+                ]
+            ),
+        )
+        if not filtered_room_state:
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
 
-        create_event = room_state[(EventTypes.Create, "")]
-        power_levels = room_state.get((EventTypes.PowerLevels, ""))
+        create_event = filtered_room_state[(EventTypes.Create, "")]
+        power_levels = filtered_room_state.get((EventTypes.PowerLevels, ""))
 
         if power_levels is not None:
             # We pick the local user with the highest power.
@@ -634,7 +645,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
 
         # Now we check if the user we're granting admin rights to is already in
         # the room. If not and it's not a public room we invite them.
-        member_event = room_state.get((EventTypes.Member, user_to_add))
+        member_event = filtered_room_state.get((EventTypes.Member, user_to_add))
         is_joined = False
         if member_event:
             is_joined = member_event.content["membership"] in (
@@ -645,7 +656,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
         if is_joined:
             return HTTPStatus.OK, {}
 
-        join_rules = room_state.get((EventTypes.JoinRules, ""))
+        join_rules = filtered_room_state.get((EventTypes.JoinRules, ""))
         is_public = False
         if join_rules:
             is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index b19365b81a..0f9fb0abf1 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -207,7 +207,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
 
         # the room should show that the new user is a member
         r = self.get_success(
-            self.hs.get_state_handler().get_current_state(self._room_id)
+            self.hs.get_datastores().main.get_current_state(self._room_id)
         )
         self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
 
@@ -258,7 +258,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
 
         # the room should show that the new user is a member
         r = self.get_success(
-            self.hs.get_state_handler().get_current_state(self._room_id)
+            self.hs.get_datastores().main.get_current_state(self._room_id)
         )
         self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
 
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index ef5e25873c..18a5907dab 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -103,7 +103,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
 
         self.persist_event(remote_event_2, state=state_before_gap.values())
 
@@ -135,7 +135,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
         # setting. The state resolution across the old and new event will then
         # include it, and so the resolved state won't match the new state.
         state_before_gap = dict(
-            self.get_success(self.state.get_current_state(self.room_id))
+            self.get_success(self.store.get_current_state(self.room_id))
         )
         state_before_gap.pop(("m.room.history_visibility", ""))
 
@@ -177,7 +177,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
 
         self.persist_event(remote_event_2, state=state_before_gap.values())
 
@@ -207,7 +207,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
 
         self.persist_event(remote_event_2, state=state_before_gap.values())
 
@@ -247,7 +247,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
 
         self.persist_event(remote_event_2, state=state_before_gap.values())
 
@@ -289,7 +289,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
 
         self.persist_event(remote_event_2, state=state_before_gap.values())
 
@@ -323,7 +323,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
 
         self.persist_event(remote_event_2, state=state_before_gap.values())