summary refs log tree commit diff
path: root/synapse/handlers/event_auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/event_auth.py')
-rw-r--r--synapse/handlers/event_auth.py30
1 files changed, 19 insertions, 11 deletions
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index a23a8ce2a1..0db0bd7304 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -63,9 +63,18 @@ class EventAuthHandler:
             self._store, event, batched_auth_events
         )
         auth_event_ids = event.auth_event_ids()
-        auth_events_by_id = await self._store.get_events(auth_event_ids)
+
         if batched_auth_events:
-            auth_events_by_id.update(batched_auth_events)
+            # Copy the batched auth events to avoid mutating them.
+            auth_events_by_id = dict(batched_auth_events)
+            needed_auth_event_ids = set(auth_event_ids) - set(batched_auth_events)
+            if needed_auth_event_ids:
+                auth_events_by_id.update(
+                    await self._store.get_events(needed_auth_event_ids)
+                )
+        else:
+            auth_events_by_id = await self._store.get_events(auth_event_ids)
+
         check_state_dependent_auth_rules(event, auth_events_by_id.values())
 
     def compute_auth_events(
@@ -202,7 +211,7 @@ class EventAuthHandler:
         state_ids: StateMap[str],
         room_version: RoomVersion,
         user_id: str,
-        prev_member_event: Optional[EventBase],
+        prev_membership: Optional[str],
     ) -> None:
         """
         Check whether a user can join a room without an invite due to restricted join rules.
@@ -214,15 +223,14 @@ class EventAuthHandler:
             state_ids: The state of the room as it currently is.
             room_version: The room version of the room being joined.
             user_id: The user joining the room.
-            prev_member_event: The current membership event for this user.
+            prev_membership: The current membership state for this user. `None` if the
+                user has never joined the room (equivalent to "leave").
 
         Raises:
             AuthError if the user cannot join the room.
         """
         # If the member is invited or currently joined, then nothing to do.
-        if prev_member_event and (
-            prev_member_event.membership in (Membership.JOIN, Membership.INVITE)
-        ):
+        if prev_membership in (Membership.JOIN, Membership.INVITE):
             return
 
         # This is not a room with a restricted join rule, so we don't need to do the
@@ -237,7 +245,6 @@ class EventAuthHandler:
         # in any of them.
         allowed_rooms = await self.get_rooms_that_allow_join(state_ids)
         if not await self.is_user_in_rooms(allowed_rooms, user_id):
-
             # If this is a remote request, the user might be in an allowed room
             # that we do not know about.
             if get_domain_from_id(user_id) != self._server_name:
@@ -255,13 +262,14 @@ class EventAuthHandler:
             )
 
     async def has_restricted_join_rules(
-        self, state_ids: StateMap[str], room_version: RoomVersion
+        self, partial_state_ids: StateMap[str], room_version: RoomVersion
     ) -> bool:
         """
         Return if the room has the proper join rules set for access via rooms.
 
         Args:
-            state_ids: The state of the room as it currently is.
+            state_ids: The state of the room as it currently is. May be full or partial
+                state.
             room_version: The room version of the room to query.
 
         Returns:
@@ -272,7 +280,7 @@ class EventAuthHandler:
             return False
 
         # If there's no join rule, then it defaults to invite (so this doesn't apply).
-        join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
+        join_rules_event_id = partial_state_ids.get((EventTypes.JoinRules, ""), None)
         if not join_rules_event_id:
             return False