diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index ec8eb21674..49f8aa25ea 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
import random
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
from canonicaljson import encode_canonical_json
@@ -66,7 +66,7 @@ logger = logging.getLogger(__name__)
class MessageHandler:
"""Contains some read only APIs to get state about a room"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
@@ -91,7 +91,7 @@ class MessageHandler:
room_id: str,
event_type: str,
state_key: str,
- ) -> dict:
+ ) -> Optional[EventBase]:
"""Get data from a room.
Args:
@@ -115,6 +115,10 @@ class MessageHandler:
data = await self.state.get_current_state(room_id, event_type, state_key)
elif membership == Membership.LEAVE:
key = (event_type, state_key)
+ # If the membership is not JOIN, then the event ID should exist.
+ assert (
+ membership_event_id is not None
+ ), "check_user_in_room_or_world_readable returned invalid data"
room_state = await self.state_store.get_state_for_events(
[membership_event_id], StateFilter.from_types([key])
)
@@ -186,10 +190,12 @@ class MessageHandler:
event = last_events[0]
if visible_events:
- room_state = await self.state_store.get_state_for_events(
+ room_state_events = await self.state_store.get_state_for_events(
[event.event_id], state_filter=state_filter
)
- room_state = room_state[event.event_id]
+ room_state = room_state_events[
+ event.event_id
+ ] # type: Mapping[Any, EventBase]
else:
raise AuthError(
403,
@@ -210,10 +216,14 @@ class MessageHandler:
)
room_state = await self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE:
- room_state = await self.state_store.get_state_for_events(
+ # If the membership is not JOIN, then the event ID should exist.
+ assert (
+ membership_event_id is not None
+ ), "check_user_in_room_or_world_readable returned invalid data"
+ room_state_events = await self.state_store.get_state_for_events(
[membership_event_id], state_filter=state_filter
)
- room_state = room_state[membership_event_id]
+ room_state = room_state_events[membership_event_id]
now = self.clock.time_msec()
events = await self._event_serializer.serialize_events(
|