summary refs log tree commit diff
path: root/synapse/handlers/message.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/message.py')
-rw-r--r--synapse/handlers/message.py50
1 files changed, 42 insertions, 8 deletions
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index cb1bc4c06f..7ca126dbd1 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -55,7 +55,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
-from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
+from synapse.types import (
+    MutableStateMap,
+    Requester,
+    RoomAlias,
+    StreamToken,
+    UserID,
+    create_requester,
+)
 from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, gather_results
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -78,7 +85,7 @@ class MessageHandler:
         self.state = hs.get_state_handler()
         self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
-        self.state_store = self.storage.state
+        self.state_storage = self.storage.state
         self._event_serializer = hs.get_event_client_serializer()
         self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages
 
@@ -125,7 +132,7 @@ class MessageHandler:
             assert (
                 membership_event_id is not None
             ), "check_user_in_room_or_world_readable returned invalid data"
-            room_state = await self.state_store.get_state_for_events(
+            room_state = await self.state_storage.get_state_for_events(
                 [membership_event_id], StateFilter.from_types([key])
             )
             data = room_state[membership_event_id].get(key)
@@ -186,7 +193,7 @@ class MessageHandler:
 
             # check whether the user is in the room at that time to determine
             # whether they should be treated as peeking.
-            state_map = await self.state_store.get_state_for_event(
+            state_map = await self.state_storage.get_state_for_event(
                 last_event.event_id,
                 StateFilter.from_types([(EventTypes.Member, user_id)]),
             )
@@ -207,7 +214,7 @@ class MessageHandler:
             )
 
             if visible_events:
-                room_state_events = await self.state_store.get_state_for_events(
+                room_state_events = await self.state_storage.get_state_for_events(
                     [last_event.event_id], state_filter=state_filter
                 )
                 room_state: Mapping[Any, EventBase] = room_state_events[
@@ -237,7 +244,7 @@ class MessageHandler:
                 assert (
                     membership_event_id is not None
                 ), "check_user_in_room_or_world_readable returned invalid data"
-                room_state_events = await self.state_store.get_state_for_events(
+                room_state_events = await self.state_storage.get_state_for_events(
                     [membership_event_id], state_filter=state_filter
                 )
                 room_state = room_state_events[membership_event_id]
@@ -1022,8 +1029,35 @@ class EventCreationHandler:
             #
             # TODO(faster_joins): figure out how this works, and make sure that the
             #   old state is complete.
-            old_state = await self.store.get_events_as_list(state_event_ids)
-            context = await self.state.compute_event_context(event, old_state=old_state)
+            metadata = await self.store.get_metadata_for_events(state_event_ids)
+
+            state_map_for_event: MutableStateMap[str] = {}
+            for state_id in state_event_ids:
+                data = metadata.get(state_id)
+                if data is None:
+                    # We're trying to persist a new historical batch of events
+                    # with the given state, e.g. via
+                    # `RoomBatchSendEventRestServlet`. The state can be inferred
+                    # by Synapse or set directly by the client.
+                    #
+                    # Either way, we should have persisted all the state before
+                    # getting here.
+                    raise Exception(
+                        f"State event {state_id} not found in DB,"
+                        " Synapse should have persisted it before using it."
+                    )
+
+                if data.state_key is None:
+                    raise Exception(
+                        f"Trying to set non-state event {state_id} as state"
+                    )
+
+                state_map_for_event[(data.event_type, data.state_key)] = state_id
+
+            context = await self.state.compute_event_context(
+                event,
+                state_ids_before_event=state_map_for_event,
+            )
         else:
             context = await self.state.compute_event_context(event)