summary refs log tree commit diff
path: root/synapse/handlers/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r--synapse/handlers/federation.py56
1 files changed, 30 insertions, 26 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 593932adb7..43f2986f89 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -72,7 +72,13 @@ from synapse.replication.http.federation import (
 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
+from synapse.types import (
+    JsonDict,
+    MutableStateMap,
+    StateMap,
+    UserID,
+    get_domain_from_id,
+)
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
@@ -96,7 +102,7 @@ class _NewEventInfo:
 
     event = attr.ib(type=EventBase)
     state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
-    auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
+    auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None)
 
 
 class FederationHandler(BaseHandler):
@@ -434,11 +440,11 @@ class FederationHandler(BaseHandler):
         if not prevs - seen:
             return
 
-        latest = await self.store.get_latest_event_ids_in_room(room_id)
+        latest_list = await self.store.get_latest_event_ids_in_room(room_id)
 
         # We add the prev events that we have seen to the latest
         # list to ensure the remote server doesn't give them to us
-        latest = set(latest)
+        latest = set(latest_list)
         latest |= seen
 
         logger.info(
@@ -775,7 +781,7 @@ class FederationHandler(BaseHandler):
                     # keys across all devices.
                     current_keys = [
                         key
-                        for device in cached_devices
+                        for device in cached_devices.values()
                         for key in device.get("keys", {}).get("keys", {}).values()
                     ]
 
@@ -1777,9 +1783,7 @@ class FederationHandler(BaseHandler):
         """Returns the state at the event. i.e. not including said event.
         """
 
-        event = await self.store.get_event(
-            event_id, allow_none=False, check_room_id=room_id
-        )
+        event = await self.store.get_event(event_id, check_room_id=room_id)
 
         state_groups = await self.state_store.get_state_groups(room_id, [event_id])
 
@@ -1805,9 +1809,7 @@ class FederationHandler(BaseHandler):
     async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
         """Returns the state at the event. i.e. not including said event.
         """
-        event = await self.store.get_event(
-            event_id, allow_none=False, check_room_id=room_id
-        )
+        event = await self.store.get_event(event_id, check_room_id=room_id)
 
         state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
 
@@ -1877,8 +1879,8 @@ class FederationHandler(BaseHandler):
         else:
             return None
 
-    def get_min_depth_for_context(self, context):
-        return self.store.get_min_depth(context)
+    async def get_min_depth_for_context(self, context):
+        return await self.store.get_min_depth(context)
 
     async def _handle_new_event(
         self, origin, event, state=None, auth_events=None, backfilled=False
@@ -2057,7 +2059,7 @@ class FederationHandler(BaseHandler):
         origin: str,
         event: EventBase,
         state: Optional[Iterable[EventBase]],
-        auth_events: Optional[StateMap[EventBase]],
+        auth_events: Optional[MutableStateMap[EventBase]],
         backfilled: bool,
     ) -> EventContext:
         context = await self.state_handler.compute_event_context(event, old_state=state)
@@ -2107,8 +2109,8 @@ class FederationHandler(BaseHandler):
         if backfilled or event.internal_metadata.is_outlier():
             return
 
-        extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
-        extrem_ids = set(extrem_ids)
+        extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
+        extrem_ids = set(extrem_ids_list)
         prev_event_ids = set(event.prev_event_ids())
 
         if extrem_ids == prev_event_ids:
@@ -2138,10 +2140,12 @@ class FederationHandler(BaseHandler):
             )
             state_sets = list(state_sets.values())
             state_sets.append(state)
-            current_state_ids = await self.state_handler.resolve_events(
+            current_states = await self.state_handler.resolve_events(
                 room_version, state_sets, event
             )
-            current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
+            current_state_ids = {
+                k: e.event_id for k, e in current_states.items()
+            }  # type: StateMap[str]
         else:
             current_state_ids = await self.state_handler.get_current_state_ids(
                 event.room_id, latest_event_ids=extrem_ids
@@ -2153,11 +2157,13 @@ class FederationHandler(BaseHandler):
 
         # Now check if event pass auth against said current state
         auth_types = auth_types_for_event(event)
-        current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
+        current_state_ids_list = [
+            e for k, e in current_state_ids.items() if k in auth_types
+        ]
 
-        current_auth_events = await self.store.get_events(current_state_ids)
+        auth_events_map = await self.store.get_events(current_state_ids_list)
         current_auth_events = {
-            (e.type, e.state_key): e for e in current_auth_events.values()
+            (e.type, e.state_key): e for e in auth_events_map.values()
         }
 
         try:
@@ -2173,9 +2179,7 @@ class FederationHandler(BaseHandler):
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
-        event = await self.store.get_event(
-            event_id, allow_none=False, check_room_id=room_id
-        )
+        event = await self.store.get_event(event_id, check_room_id=room_id)
 
         # Just go through and process each event in `remote_auth_chain`. We
         # don't want to fall into the trap of `missing` being wrong.
@@ -2227,7 +2231,7 @@ class FederationHandler(BaseHandler):
         origin: str,
         event: EventBase,
         context: EventContext,
-        auth_events: StateMap[EventBase],
+        auth_events: MutableStateMap[EventBase],
     ) -> EventContext:
         """
 
@@ -2278,7 +2282,7 @@ class FederationHandler(BaseHandler):
         origin: str,
         event: EventBase,
         context: EventContext,
-        auth_events: StateMap[EventBase],
+        auth_events: MutableStateMap[EventBase],
     ) -> EventContext:
         """Helper for do_auth. See there for docs.