diff options
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r-- | synapse/handlers/federation.py | 56 |
1 files changed, 30 insertions, 26 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 593932adb7..43f2986f89 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -72,7 +72,13 @@ from synapse.replication.http.federation import ( from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id +from synapse.types import ( + JsonDict, + MutableStateMap, + StateMap, + UserID, + get_domain_from_id, +) from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination @@ -96,7 +102,7 @@ class _NewEventInfo: event = attr.ib(type=EventBase) state = attr.ib(type=Optional[Sequence[EventBase]], default=None) - auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None) + auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None) class FederationHandler(BaseHandler): @@ -434,11 +440,11 @@ class FederationHandler(BaseHandler): if not prevs - seen: return - latest = await self.store.get_latest_event_ids_in_room(room_id) + latest_list = await self.store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us - latest = set(latest) + latest = set(latest_list) latest |= seen logger.info( @@ -775,7 +781,7 @@ class FederationHandler(BaseHandler): # keys across all devices. current_keys = [ key - for device in cached_devices + for device in cached_devices.values() for key in device.get("keys", {}).get("keys", {}).values() ] @@ -1777,9 +1783,7 @@ class FederationHandler(BaseHandler): """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups(room_id, [event_id]) @@ -1805,9 +1809,7 @@ class FederationHandler(BaseHandler): async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) @@ -1877,8 +1879,8 @@ class FederationHandler(BaseHandler): else: return None - def get_min_depth_for_context(self, context): - return self.store.get_min_depth(context) + async def get_min_depth_for_context(self, context): + return await self.store.get_min_depth(context) async def _handle_new_event( self, origin, event, state=None, auth_events=None, backfilled=False @@ -2057,7 +2059,7 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, state: Optional[Iterable[EventBase]], - auth_events: Optional[StateMap[EventBase]], + auth_events: Optional[MutableStateMap[EventBase]], backfilled: bool, ) -> EventContext: context = await self.state_handler.compute_event_context(event, old_state=state) @@ -2107,8 +2109,8 @@ class FederationHandler(BaseHandler): if backfilled or event.internal_metadata.is_outlier(): return - extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id) - extrem_ids = set(extrem_ids) + extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id) + extrem_ids = set(extrem_ids_list) prev_event_ids = set(event.prev_event_ids()) if extrem_ids == prev_event_ids: @@ -2138,10 +2140,12 @@ class FederationHandler(BaseHandler): ) state_sets = list(state_sets.values()) state_sets.append(state) - current_state_ids = await self.state_handler.resolve_events( + current_states = await self.state_handler.resolve_events( room_version, state_sets, event ) - current_state_ids = {k: e.event_id for k, e in current_state_ids.items()} + current_state_ids = { + k: e.event_id for k, e in current_states.items() + } # type: StateMap[str] else: current_state_ids = await self.state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids @@ -2153,11 +2157,13 @@ class FederationHandler(BaseHandler): # Now check if event pass auth against said current state auth_types = auth_types_for_event(event) - current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types] + current_state_ids_list = [ + e for k, e in current_state_ids.items() if k in auth_types + ] - current_auth_events = await self.store.get_events(current_state_ids) + auth_events_map = await self.store.get_events(current_state_ids_list) current_auth_events = { - (e.type, e.state_key): e for e in current_auth_events.values() + (e.type, e.state_key): e for e in auth_events_map.values() } try: @@ -2173,9 +2179,7 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) # Just go through and process each event in `remote_auth_chain`. We # don't want to fall into the trap of `missing` being wrong. @@ -2227,7 +2231,7 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, context: EventContext, - auth_events: StateMap[EventBase], + auth_events: MutableStateMap[EventBase], ) -> EventContext: """ @@ -2278,7 +2282,7 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, context: EventContext, - auth_events: StateMap[EventBase], + auth_events: MutableStateMap[EventBase], ) -> EventContext: """Helper for do_auth. See there for docs. |