diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/event_auth.py | 43 |
1 files changed, 34 insertions, 9 deletions
diff --git a/synapse/event_auth.py b/synapse/event_auth.py index d922c8dc35..c8b06f760e 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -23,7 +23,20 @@ import collections.abc import logging import typing -from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union +from typing import ( + Any, + ChainMap, + Dict, + Iterable, + List, + Mapping, + MutableMapping, + Optional, + Set, + Tuple, + Union, + cast, +) from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes @@ -175,12 +188,22 @@ async def check_state_independent_auth_rules( return # 2. Reject if event has auth_events that: ... + auth_events: ChainMap[str, EventBase] = ChainMap() if batched_auth_events: - # Copy the batched auth events to avoid mutating them. - auth_events = dict(batched_auth_events) - needed_auth_event_ids = set(event.auth_event_ids()) - batched_auth_events.keys() + # batched_auth_events can become very large. To avoid repeatedly copying it, which + # would significantly impact performance, we use a ChainMap. + # batched_auth_events must be cast to MutableMapping because .new_child() requires + # this type. This casting is safe as the mapping is never mutated. + auth_events = auth_events.new_child( + cast(MutableMapping[str, "EventBase"], batched_auth_events) + ) + needed_auth_event_ids = [ + event_id + for event_id in event.auth_event_ids() + if event_id not in batched_auth_events + ] if needed_auth_event_ids: - auth_events.update( + auth_events = auth_events.new_child( await store.get_events( needed_auth_event_ids, redact_behaviour=EventRedactBehaviour.as_is, @@ -188,10 +211,12 @@ async def check_state_independent_auth_rules( ) ) else: - auth_events = await store.get_events( - event.auth_event_ids(), - redact_behaviour=EventRedactBehaviour.as_is, - allow_rejected=True, + auth_events = auth_events.new_child( + await store.get_events( + event.auth_event_ids(), + redact_behaviour=EventRedactBehaviour.as_is, + allow_rejected=True, + ) ) room_id = event.room_id |