summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8422.misc1
-rw-r--r--synapse/federation/federation_client.py4
-rw-r--r--synapse/handlers/federation.py13
-rw-r--r--synapse/storage/databases/state/store.py4
-rw-r--r--synapse/storage/persist_events.py2
-rw-r--r--synapse/storage/state.py6
6 files changed, 19 insertions, 11 deletions
diff --git a/changelog.d/8422.misc b/changelog.d/8422.misc
new file mode 100644
index 0000000000..03fba120c6
--- /dev/null
+++ b/changelog.d/8422.misc
@@ -0,0 +1 @@
+Typing fixes for `synapse.handlers.federation`.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 688d43fffb..302b2f69bc 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -24,10 +24,12 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Sequence,
     Tuple,
     TypeVar,
+    Union,
 )
 
 from prometheus_client import Counter
@@ -501,7 +503,7 @@ class FederationClient(FederationBase):
         user_id: str,
         membership: str,
         content: dict,
-        params: Dict[str, str],
+        params: Optional[Mapping[str, Union[str, Iterable[str]]]],
     ) -> Tuple[str, EventBase, RoomVersion]:
         """
         Creates an m.room.member event, with context, without participating in the room.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 5bcfb231b2..0073e7c996 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -155,8 +155,9 @@ class FederationHandler(BaseHandler):
             self._device_list_updater = hs.get_device_handler().device_list_updater
             self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite
 
-        # When joining a room we need to queue any events for that room up
-        self.room_queues = {}
+        # When joining a room we need to queue any events for that room up.
+        # For each room, a list of (pdu, origin) tuples.
+        self.room_queues = {}  # type: Dict[str, List[Tuple[EventBase, str]]]
         self._room_pdu_linearizer = Linearizer("fed_room_pdu")
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
@@ -814,6 +815,9 @@ class FederationHandler(BaseHandler):
             dest, room_id, limit=limit, extremities=extremities
         )
 
+        if not events:
+            return []
+
         # ideally we'd sanity check the events here for excess prev_events etc,
         # but it's hard to reject events at this point without completely
         # breaking backfill in the same way that it is currently broken by
@@ -2164,10 +2168,10 @@ class FederationHandler(BaseHandler):
             # given state at the event. This should correctly handle cases
             # like bans, especially with state res v2.
 
-            state_sets = await self.state_store.get_state_groups(
+            state_sets_d = await self.state_store.get_state_groups(
                 event.room_id, extrem_ids
             )
-            state_sets = list(state_sets.values())
+            state_sets = list(state_sets_d.values())  # type: List[Iterable[EventBase]]
             state_sets.append(state)
             current_states = await self.state_handler.resolve_events(
                 room_version, state_sets, event
@@ -2958,6 +2962,7 @@ class FederationHandler(BaseHandler):
             )
             return result["max_stream_id"]
         else:
+            assert self.storage.persistence
             max_stream_token = await self.storage.persistence.persist_events(
                 event_and_contexts, backfilled=backfilled
             )
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 989f0cbc9d..0e31cc811a 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -24,7 +24,7 @@ from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStor
 from synapse.storage.state import StateFilter
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import StateMap
+from synapse.types import MutableStateMap, StateMap
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
 
@@ -208,7 +208,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
     async def _get_state_for_groups(
         self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
-    ) -> Dict[int, StateMap[str]]:
+    ) -> Dict[int, MutableStateMap[str]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
 
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 603cd7d825..ded6cf9655 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -197,7 +197,7 @@ class EventsPersistenceStorage:
 
     async def persist_events(
         self,
-        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
     ) -> RoomStreamToken:
         """
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 8f68d968f0..08a69f2f96 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -20,7 +20,7 @@ import attr
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
-from synapse.types import StateMap
+from synapse.types import MutableStateMap, StateMap
 
 logger = logging.getLogger(__name__)
 
@@ -349,7 +349,7 @@ class StateGroupStorage:
 
     async def get_state_groups_ids(
         self, _room_id: str, event_ids: Iterable[str]
-    ) -> Dict[int, StateMap[str]]:
+    ) -> Dict[int, MutableStateMap[str]]:
         """Get the event IDs of all the state for the state groups for the given events
 
         Args:
@@ -532,7 +532,7 @@ class StateGroupStorage:
 
     def _get_state_for_groups(
         self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
-    ) -> Awaitable[Dict[int, StateMap[str]]]:
+    ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key