diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 8249ca1ed2..3bbad0271b 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Collection, List, Optional, Union
+from typing import TYPE_CHECKING, Collection, List, Mapping, Optional, Union
from synapse import event_auth
from synapse.api.constants import (
@@ -29,7 +29,6 @@ from synapse.event_auth import (
)
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
-from synapse.events.snapshot import EventContext
from synapse.types import StateMap, get_domain_from_id
if TYPE_CHECKING:
@@ -51,12 +50,21 @@ class EventAuthHandler:
async def check_auth_rules_from_context(
self,
event: EventBase,
- context: EventContext,
+ batched_auth_events: Optional[Mapping[str, EventBase]] = None,
) -> None:
- """Check an event passes the auth rules at its own auth events"""
- await check_state_independent_auth_rules(self._store, event)
+ """Check an event passes the auth rules at its own auth events
+ Args:
+ event: event to be authed
+ batched_auth_events: if the event being authed is part of a batch, any events
+ from the same batch that may be necessary to auth the current event
+ """
+ await check_state_independent_auth_rules(
+ self._store, event, batched_auth_events
+ )
auth_event_ids = event.auth_event_ids()
auth_events_by_id = await self._store.get_events(auth_event_ids)
+ if batched_auth_events:
+ auth_events_by_id.update(batched_auth_events)
check_state_dependent_auth_rules(event, auth_events_by_id.values())
def compute_auth_events(
|