summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2023-03-01 17:33:04 +0000
committerErik Johnston <erik@matrix.org>2023-03-01 17:33:04 +0000
commit2ab39ac4976b7e063a659b300ede24add5e1dce3 (patch)
treeb95aee67bf4b7188fc1e72352461bc5e124f993a
parentRemove support for aggregating reactions (#15172) (diff)
downloadsynapse-2ab39ac4976b7e063a659b300ede24add5e1dce3.tar.xz
-rw-r--r--synapse/event_auth.py20
-rw-r--r--synapse/events/snapshot.py37
-rw-r--r--synapse/handlers/event_auth.py14
-rw-r--r--synapse/handlers/room.py2
4 files changed, 44 insertions, 29 deletions
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 4d6d1b8ebd..813ef00ceb 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -168,13 +168,21 @@ async def check_state_independent_auth_rules(
         return
 
     # 2. Reject if event has auth_events that: ...
-    auth_events = await store.get_events(
-        event.auth_event_ids(),
-        redact_behaviour=EventRedactBehaviour.as_is,
-        allow_rejected=True,
-    )
     if batched_auth_events:
-        auth_events.update(batched_auth_events)
+        auth_event_ids = event.auth_event_ids()
+        auth_events = dict(batched_auth_events)
+        if set(auth_event_ids) - batched_auth_events.keys():
+            auth_events.update(
+                await store.get_events(
+                    set(auth_event_ids) - batched_auth_events.keys()
+                )
+            )
+    else:
+        auth_events = await store.get_events(
+            event.auth_event_ids(),
+            redact_behaviour=EventRedactBehaviour.as_is,
+            allow_rejected=True,
+        )
 
     room_id = event.room_id
     auth_dict: MutableStateMap[str] = {}
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index a91a5d1e3c..1bb3d8f476 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -135,6 +135,8 @@ class EventContext(UnpersistedEventContextBase):
     delta_ids: Optional[StateMap[str]] = None
     app_service: Optional[ApplicationService] = None
 
+    _state_map_before_event: Optional[StateMap[str]] = None
+
     partial_state: bool = False
 
     @staticmethod
@@ -293,6 +295,11 @@ class EventContext(UnpersistedEventContextBase):
             Maps a (type, state_key) to the event ID of the state event matching
             this tuple.
         """
+        if self._state_map_before_event is not None:
+            if state_filter is not None:
+                return state_filter.filter_state(self._state_map_before_event)
+            return self._state_map_before_event
+
         assert self.state_group_before_event is not None
         return await self._storage.state.get_state_ids_for_group(
             self.state_group_before_event, state_filter
@@ -374,26 +381,16 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
 
         events_and_persisted_context = []
         for event, unpersisted_context in amended_events_and_context:
-            if event.is_state():
-                context = EventContext(
-                    storage=unpersisted_context._storage,
-                    state_group=unpersisted_context.state_group_after_event,
-                    state_group_before_event=unpersisted_context.state_group_before_event,
-                    state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
-                    partial_state=unpersisted_context.partial_state,
-                    prev_group=unpersisted_context.state_group_before_event,
-                    delta_ids=unpersisted_context.state_delta_due_to_event,
-                )
-            else:
-                context = EventContext(
-                    storage=unpersisted_context._storage,
-                    state_group=unpersisted_context.state_group_after_event,
-                    state_group_before_event=unpersisted_context.state_group_before_event,
-                    state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
-                    partial_state=unpersisted_context.partial_state,
-                    prev_group=unpersisted_context.prev_group_for_state_group_before_event,
-                    delta_ids=unpersisted_context.delta_ids_to_state_group_before_event,
-                )
+            context = EventContext(
+                storage=unpersisted_context._storage,
+                state_group=unpersisted_context.state_group_after_event,
+                state_group_before_event=unpersisted_context.state_group_before_event,
+                state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
+                partial_state=unpersisted_context.partial_state,
+                prev_group=unpersisted_context.prev_group_for_state_group_before_event,
+                delta_ids=unpersisted_context.delta_ids_to_state_group_before_event,
+                state_map_before_event=unpersisted_context.state_map_before_event,
+            )
             events_and_persisted_context.append((event, context))
         return events_and_persisted_context
 
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index c508861b6a..272f312fda 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -63,9 +63,19 @@ class EventAuthHandler:
             self._store, event, batched_auth_events
         )
         auth_event_ids = event.auth_event_ids()
-        auth_events_by_id = await self._store.get_events(auth_event_ids)
+        logger.info("Batched auth events %s", list(batched_auth_events.keys()))
+        logger.info("auth events %s", auth_event_ids)
         if batched_auth_events:
-            auth_events_by_id.update(batched_auth_events)
+            auth_events_by_id = dict(batched_auth_events)
+            if set(auth_event_ids) - set(batched_auth_events):
+                logger.info("fetching %s", set(auth_event_ids) - set(batched_auth_events))
+                auth_events_by_id.update(
+                    await self._store.get_events(
+                        set(auth_event_ids) - set(batched_auth_events)
+                    )
+                )
+        else:
+            auth_events_by_id = await self._store.get_events(auth_event_ids)
         check_state_dependent_auth_rules(event, auth_events_by_id.values())
 
     def compute_auth_events(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index b1784638f4..c018d24e2e 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1123,7 +1123,7 @@ class RoomCreationHandler:
                 event_dict,
                 prev_event_ids=prev_event,
                 depth=depth,
-                state_map=state_map,
+                state_map=dict(state_map),
                 for_batch=for_batch,
             )