diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index a23a8ce2a1..0db0bd7304 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -63,9 +63,18 @@ class EventAuthHandler:
self._store, event, batched_auth_events
)
auth_event_ids = event.auth_event_ids()
- auth_events_by_id = await self._store.get_events(auth_event_ids)
+
if batched_auth_events:
- auth_events_by_id.update(batched_auth_events)
+ # Copy the batched auth events to avoid mutating them.
+ auth_events_by_id = dict(batched_auth_events)
+ needed_auth_event_ids = set(auth_event_ids) - set(batched_auth_events)
+ if needed_auth_event_ids:
+ auth_events_by_id.update(
+ await self._store.get_events(needed_auth_event_ids)
+ )
+ else:
+ auth_events_by_id = await self._store.get_events(auth_event_ids)
+
check_state_dependent_auth_rules(event, auth_events_by_id.values())
def compute_auth_events(
@@ -202,7 +211,7 @@ class EventAuthHandler:
state_ids: StateMap[str],
room_version: RoomVersion,
user_id: str,
- prev_member_event: Optional[EventBase],
+ prev_membership: Optional[str],
) -> None:
"""
Check whether a user can join a room without an invite due to restricted join rules.
@@ -214,15 +223,14 @@ class EventAuthHandler:
state_ids: The state of the room as it currently is.
room_version: The room version of the room being joined.
user_id: The user joining the room.
- prev_member_event: The current membership event for this user.
+ prev_membership: The current membership state for this user. `None` if the
+ user has never joined the room (equivalent to "leave").
Raises:
AuthError if the user cannot join the room.
"""
# If the member is invited or currently joined, then nothing to do.
- if prev_member_event and (
- prev_member_event.membership in (Membership.JOIN, Membership.INVITE)
- ):
+ if prev_membership in (Membership.JOIN, Membership.INVITE):
return
# This is not a room with a restricted join rule, so we don't need to do the
@@ -237,7 +245,6 @@ class EventAuthHandler:
# in any of them.
allowed_rooms = await self.get_rooms_that_allow_join(state_ids)
if not await self.is_user_in_rooms(allowed_rooms, user_id):
-
# If this is a remote request, the user might be in an allowed room
# that we do not know about.
if get_domain_from_id(user_id) != self._server_name:
@@ -255,13 +262,14 @@ class EventAuthHandler:
)
async def has_restricted_join_rules(
- self, state_ids: StateMap[str], room_version: RoomVersion
+ self, partial_state_ids: StateMap[str], room_version: RoomVersion
) -> bool:
"""
Return if the room has the proper join rules set for access via rooms.
Args:
- state_ids: The state of the room as it currently is.
+ state_ids: The state of the room as it currently is. May be full or partial
+ state.
room_version: The room version of the room to query.
Returns:
@@ -272,7 +280,7 @@ class EventAuthHandler:
return False
# If there's no join rule, then it defaults to invite (so this doesn't apply).
- join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
+ join_rules_event_id = partial_state_ids.get((EventTypes.JoinRules, ""), None)
if not join_rules_event_id:
return False
|