summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/event_auth.py43
1 files changed, 34 insertions, 9 deletions
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index d922c8dc35..c8b06f760e 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -23,7 +23,20 @@
 import collections.abc
 import logging
 import typing
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union
+from typing import (
+    Any,
+    ChainMap,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+    cast,
+)
 
 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
@@ -175,12 +188,22 @@ async def check_state_independent_auth_rules(
         return
 
     # 2. Reject if event has auth_events that: ...
+    auth_events: ChainMap[str, EventBase] = ChainMap()
     if batched_auth_events:
-        # Copy the batched auth events to avoid mutating them.
-        auth_events = dict(batched_auth_events)
-        needed_auth_event_ids = set(event.auth_event_ids()) - batched_auth_events.keys()
+        # batched_auth_events can become very large. To avoid repeatedly copying it, which
+        # would significantly impact performance, we use a ChainMap.
+        # batched_auth_events must be cast to MutableMapping because .new_child() requires
+        # this type. This casting is safe as the mapping is never mutated.
+        auth_events = auth_events.new_child(
+            cast(MutableMapping[str, "EventBase"], batched_auth_events)
+        )
+        needed_auth_event_ids = [
+            event_id
+            for event_id in event.auth_event_ids()
+            if event_id not in batched_auth_events
+        ]
         if needed_auth_event_ids:
-            auth_events.update(
+            auth_events = auth_events.new_child(
                 await store.get_events(
                     needed_auth_event_ids,
                     redact_behaviour=EventRedactBehaviour.as_is,
@@ -188,10 +211,12 @@ async def check_state_independent_auth_rules(
                 )
             )
     else:
-        auth_events = await store.get_events(
-            event.auth_event_ids(),
-            redact_behaviour=EventRedactBehaviour.as_is,
-            allow_rejected=True,
+        auth_events = auth_events.new_child(
+            await store.get_events(
+                event.auth_event_ids(),
+                redact_behaviour=EventRedactBehaviour.as_is,
+                allow_rejected=True,
+            )
         )
 
     room_id = event.room_id