diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 4d6d1b8ebd..813ef00ceb 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -168,13 +168,21 @@ async def check_state_independent_auth_rules(
return
# 2. Reject if event has auth_events that: ...
- auth_events = await store.get_events(
- event.auth_event_ids(),
- redact_behaviour=EventRedactBehaviour.as_is,
- allow_rejected=True,
- )
if batched_auth_events:
- auth_events.update(batched_auth_events)
+ auth_event_ids = event.auth_event_ids()
+ auth_events = dict(batched_auth_events)
+ if set(auth_event_ids) - batched_auth_events.keys():
+ auth_events.update(
+ await store.get_events(
+ set(auth_event_ids) - batched_auth_events.keys()
+ )
+ )
+ else:
+ auth_events = await store.get_events(
+ event.auth_event_ids(),
+ redact_behaviour=EventRedactBehaviour.as_is,
+ allow_rejected=True,
+ )
room_id = event.room_id
auth_dict: MutableStateMap[str] = {}
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index a91a5d1e3c..1bb3d8f476 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -135,6 +135,8 @@ class EventContext(UnpersistedEventContextBase):
delta_ids: Optional[StateMap[str]] = None
app_service: Optional[ApplicationService] = None
+ _state_map_before_event: Optional[StateMap[str]] = None
+
partial_state: bool = False
@staticmethod
@@ -293,6 +295,11 @@ class EventContext(UnpersistedEventContextBase):
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
+ if self._state_map_before_event is not None:
+ if state_filter is not None:
+ return state_filter.filter_state(self._state_map_before_event)
+ return self._state_map_before_event
+
assert self.state_group_before_event is not None
return await self._storage.state.get_state_ids_for_group(
self.state_group_before_event, state_filter
@@ -374,26 +381,16 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
events_and_persisted_context = []
for event, unpersisted_context in amended_events_and_context:
- if event.is_state():
- context = EventContext(
- storage=unpersisted_context._storage,
- state_group=unpersisted_context.state_group_after_event,
- state_group_before_event=unpersisted_context.state_group_before_event,
- state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
- partial_state=unpersisted_context.partial_state,
- prev_group=unpersisted_context.state_group_before_event,
- delta_ids=unpersisted_context.state_delta_due_to_event,
- )
- else:
- context = EventContext(
- storage=unpersisted_context._storage,
- state_group=unpersisted_context.state_group_after_event,
- state_group_before_event=unpersisted_context.state_group_before_event,
- state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
- partial_state=unpersisted_context.partial_state,
- prev_group=unpersisted_context.prev_group_for_state_group_before_event,
- delta_ids=unpersisted_context.delta_ids_to_state_group_before_event,
- )
+ context = EventContext(
+ storage=unpersisted_context._storage,
+ state_group=unpersisted_context.state_group_after_event,
+ state_group_before_event=unpersisted_context.state_group_before_event,
+ state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
+ partial_state=unpersisted_context.partial_state,
+ prev_group=unpersisted_context.prev_group_for_state_group_before_event,
+ delta_ids=unpersisted_context.delta_ids_to_state_group_before_event,
+ state_map_before_event=unpersisted_context.state_map_before_event,
+ )
events_and_persisted_context.append((event, context))
return events_and_persisted_context
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index c508861b6a..272f312fda 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -63,9 +63,19 @@ class EventAuthHandler:
self._store, event, batched_auth_events
)
auth_event_ids = event.auth_event_ids()
- auth_events_by_id = await self._store.get_events(auth_event_ids)
+ logger.info("Batched auth events %s", list(batched_auth_events.keys()))
+ logger.info("auth events %s", auth_event_ids)
if batched_auth_events:
- auth_events_by_id.update(batched_auth_events)
+ auth_events_by_id = dict(batched_auth_events)
+ if set(auth_event_ids) - set(batched_auth_events):
+ logger.info("fetching %s", set(auth_event_ids) - set(batched_auth_events))
+ auth_events_by_id.update(
+ await self._store.get_events(
+ set(auth_event_ids) - set(batched_auth_events)
+ )
+ )
+ else:
+ auth_events_by_id = await self._store.get_events(auth_event_ids)
check_state_dependent_auth_rules(event, auth_events_by_id.values())
def compute_auth_events(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index b1784638f4..c018d24e2e 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1123,7 +1123,7 @@ class RoomCreationHandler:
event_dict,
prev_event_ids=prev_event,
depth=depth,
- state_map=state_map,
+ state_map=dict(state_map),
for_batch=for_batch,
)
|