summary refs log tree commit diff
path: root/synapse/handlers/sync.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/sync.py')
-rw-r--r--synapse/handlers/sync.py27
1 files changed, 19 insertions, 8 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b5859dcb28..a1d41358d9 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -621,21 +621,32 @@ class SyncHandler:
         )
 
     async def get_state_after_event(
-        self, event: EventBase, state_filter: Optional[StateFilter] = None
+        self, event_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
         """
         Get the room state after the given event
 
         Args:
-            event: event of interest
+            event_id: event of interest
             state_filter: The state filter used to fetch state from the database.
         """
         state_ids = await self._state_storage_controller.get_state_ids_for_event(
-            event.event_id, state_filter=state_filter or StateFilter.all()
+            event_id, state_filter=state_filter or StateFilter.all()
         )
-        if event.is_state():
+
+        # using get_metadata_for_events here (instead of get_event) sidesteps an issue
+        # with redactions: if `event_id` is a redaction event, and we don't have the
+        # original (possibly because it got purged), get_event will refuse to return
+        # the redaction event, which isn't terribly helpful here.
+        #
+        # (To be fair, in that case we could assume it's *not* a state event, and
+        # therefore we don't need to worry about it. But still, it seems cleaner just
+        # to pull the metadata.)
+        m = (await self.store.get_metadata_for_events([event_id]))[event_id]
+        if m.state_key is not None and m.rejection_reason is None:
             state_ids = dict(state_ids)
-            state_ids[(event.type, event.state_key)] = event.event_id
+            state_ids[(m.event_type, m.state_key)] = event_id
+
         return state_ids
 
     async def get_state_at(
@@ -654,14 +665,14 @@ class SyncHandler:
         # FIXME: This gets the state at the latest event before the stream ordering,
         # which might not be the same as the "current state" of the room at the time
         # of the stream token if there were multiple forward extremities at the time.
-        last_event = await self.store.get_last_event_in_room_before_stream_ordering(
+        last_event_id = await self.store.get_last_event_in_room_before_stream_ordering(
             room_id,
             end_token=stream_position.room_key,
         )
 
-        if last_event:
+        if last_event_id:
             state = await self.get_state_after_event(
-                last_event, state_filter=state_filter or StateFilter.all()
+                last_event_id, state_filter=state_filter or StateFilter.all()
             )
 
         else: