diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index e0be9f88cc..af55874b5c 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -16,18 +16,7 @@
import collections.abc
import logging
import typing
-from typing import (
- Any,
- Collection,
- Dict,
- Iterable,
- List,
- Mapping,
- Optional,
- Set,
- Tuple,
- Union,
-)
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
@@ -56,7 +45,13 @@ from synapse.api.room_versions import (
RoomVersions,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import MutableStateMap, StateMap, UserID, get_domain_from_id
+from synapse.types import (
+ MutableStateMap,
+ StateMap,
+ StrCollection,
+ UserID,
+ get_domain_from_id,
+)
if typing.TYPE_CHECKING:
# conditional imports to avoid import cycle
@@ -69,7 +64,7 @@ logger = logging.getLogger(__name__)
class _EventSourceStore(Protocol):
async def get_events(
self,
- event_ids: Collection[str],
+ event_ids: StrCollection,
redact_behaviour: EventRedactBehaviour,
get_prev_content: bool = False,
allow_rejected: bool = False,
@@ -173,13 +168,24 @@ async def check_state_independent_auth_rules(
return
# 2. Reject if event has auth_events that: ...
- auth_events = await store.get_events(
- event.auth_event_ids(),
- redact_behaviour=EventRedactBehaviour.as_is,
- allow_rejected=True,
- )
if batched_auth_events:
- auth_events.update(batched_auth_events)
+ # Copy the batched auth events to avoid mutating them.
+ auth_events = dict(batched_auth_events)
+ needed_auth_event_ids = set(event.auth_event_ids()) - batched_auth_events.keys()
+ if needed_auth_event_ids:
+ auth_events.update(
+ await store.get_events(
+ needed_auth_event_ids,
+ redact_behaviour=EventRedactBehaviour.as_is,
+ allow_rejected=True,
+ )
+ )
+ else:
+ auth_events = await store.get_events(
+ event.auth_event_ids(),
+ redact_behaviour=EventRedactBehaviour.as_is,
+ allow_rejected=True,
+ )
room_id = event.room_id
auth_dict: MutableStateMap[str] = {}
|