summary refs log tree commit diff
path: root/synapse/handlers/federation_event.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation_event.py')
-rw-r--r--synapse/handlers/federation_event.py18
1 files changed, 12 insertions, 6 deletions
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 549b066dd9..87a0608359 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -1584,9 +1584,11 @@ class FederationEventHandler:
         if guest_access == GuestAccess.CAN_JOIN:
             return
 
-        current_state_map = await self._state_handler.get_current_state(event.room_id)
-        current_state = list(current_state_map.values())
-        await self._get_room_member_handler().kick_guest_users(current_state)
+        current_state = await self._storage_controllers.state.get_current_state(
+            event.room_id
+        )
+        current_state_list = list(current_state.values())
+        await self._get_room_member_handler().kick_guest_users(current_state_list)
 
     async def _check_for_soft_fail(
         self,
@@ -1614,6 +1616,9 @@ class FederationEventHandler:
         room_version = await self._store.get_room_version_id(event.room_id)
         room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
+        # The event types we want to pull from the "current" state.
+        auth_types = auth_types_for_event(room_version_obj, event)
+
         # Calculate the "current state".
         if state_ids is not None:
             # If we're explicitly given the state then we won't have all the
@@ -1643,8 +1648,10 @@ class FederationEventHandler:
                 )
             )
         else:
-            current_state_ids = await self._state_handler.get_current_state_ids(
-                event.room_id, latest_event_ids=extrem_ids
+            current_state_ids = (
+                await self._state_storage_controller.get_current_state_ids(
+                    event.room_id, StateFilter.from_types(auth_types)
+                )
             )
 
         logger.debug(
@@ -1654,7 +1661,6 @@ class FederationEventHandler:
         )
 
         # Now check if event pass auth against said current state
-        auth_types = auth_types_for_event(room_version_obj, event)
         current_state_ids_list = [
             e for k, e in current_state_ids.items() if k in auth_types
         ]