summary refs log tree commit diff
path: root/synapse/handlers/room_member.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/room_member.py')
-rw-r--r--synapse/handlers/room_member.py71
1 files changed, 51 insertions, 20 deletions
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 65b9a655d4..9c0fdeca15 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -16,7 +16,7 @@ import abc
 import logging
 import random
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set, Tuple
 
 from synapse import types
 from synapse.api.constants import (
@@ -410,11 +410,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             historical=historical,
         )
 
-        prev_state_ids = await context.get_prev_state_ids(
-            StateFilter.from_types([(EventTypes.Member, None)])
+        prev_member_event_ids = await context.get_prev_state_ids(
+            StateFilter.from_types([(EventTypes.Member, user_id)])
         )
 
-        prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
+        prev_member_event_id = prev_member_event_ids.get(
+            (EventTypes.Member, user_id), None
+        )
 
         if event.membership == Membership.JOIN:
             newly_joined = True
@@ -790,14 +792,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         latest_event_ids = await self.store.get_prev_events_for_room(room_id)
 
-        state_before_join = await self.state_handler.compute_state_after_events(
-            room_id, latest_event_ids
+        old_membership_state_ids = await self.state_handler.compute_state_after_events(
+            room_id,
+            event_ids=latest_event_ids,
+            state_filter=StateFilter.from_types(
+                [(EventTypes.Member, target.to_string())]
+            ),
         )
 
         # TODO: Refactor into dictionary of explicitly allowed transitions
         # between old and new state, with specific error messages for some
         # transitions and generic otherwise
-        old_state_id = state_before_join.get((EventTypes.Member, target.to_string()))
+        old_state_id = old_membership_state_ids.get(
+            (EventTypes.Member, target.to_string())
+        )
         if old_state_id:
             old_state = await self.store.get_event(old_state_id, allow_none=True)
             old_membership = old_state.content.get("membership") if old_state else None
@@ -848,11 +856,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             if action == "kick":
                 raise AuthError(403, "The target user is not in the room")
 
-        is_host_in_room = await self._is_host_in_room(state_before_join)
+        is_host_in_room = await self.store.is_host_joined(room_id, self._server_name)
 
         if effective_membership_state == Membership.JOIN:
             if requester.is_guest:
-                guest_can_join = await self._can_guest_join(state_before_join)
+                guest_access_ids = await self.state_handler.compute_state_after_events(
+                    room_id,
+                    event_ids=latest_event_ids,
+                    state_filter=StateFilter.from_types([(EventTypes.GuestAccess, "")]),
+                )
+                guest_can_join = await self._can_guest_join(guest_access_ids)
                 if not guest_can_join:
                     # This should be an auth check, but guests are a local concept,
                     # so don't really fit into the general auth process.
@@ -895,7 +908,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 remote_room_hosts,
                 content,
                 is_host_in_room,
-                state_before_join,
+                latest_event_ids,
             )
             if remote_join:
                 if ratelimit:
@@ -1040,7 +1053,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         remote_room_hosts: List[str],
         content: JsonDict,
         is_host_in_room: bool,
-        state_before_join: StateMap[str],
+        latest_event_ids: Collection[str],
     ) -> Tuple[bool, List[str]]:
         """
         Check whether the server should do a remote join (as opposed to a local
@@ -1060,8 +1073,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             content: The content to use as the event body of the join. This may
                 be modified.
             is_host_in_room: True if the host is in the room.
-            state_before_join: The state before the join event (i.e. the resolution of
-                the states after its parent events).
+            latest_event_ids: The parent events of the join event.
 
         Returns:
             A tuple of:
@@ -1079,16 +1091,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # for restricted join rules, a remote join must be used.
         room_version = await self.store.get_room_version(room_id)
 
+        # Only fetch the state that we need to check if we need to worry about
+        # restricted join rules.
+        partial_state_ids = await self.state_handler.compute_state_after_events(
+            room_id,
+            latest_event_ids,
+            StateFilter.from_types(
+                [(EventTypes.JoinRules, ""), (EventTypes.Member, user_id)]
+            ),
+        )
+
         # If restricted join rules are not being used, a local join can always
         # be used.
         if not await self.event_auth_handler.has_restricted_join_rules(
-            state_before_join, room_version
+            partial_state_ids, room_version
         ):
             return False, []
 
         # If the user is invited to the room or already joined, the join
         # event can always be issued locally.
-        prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None)
+        prev_member_event_id = partial_state_ids.get((EventTypes.Member, user_id), None)
         prev_member_event = None
         if prev_member_event_id:
             prev_member_event = await self.store.get_event(prev_member_event_id)
@@ -1098,15 +1120,22 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             ):
                 return False, []
 
+        # Now we need to inspect the full membership, so pull that from the DB.
+        members_before_join = await self.state_handler.compute_state_after_events(
+            room_id,
+            latest_event_ids,
+            StateFilter.from_types([(EventTypes.Member, None)]),
+        )
+
         # If the local host has a user who can issue invites, then a local
         # join can be done.
         #
         # If not, generate a new list of remote hosts based on which
         # can issue invites.
-        event_map = await self.store.get_events(state_before_join.values())
+        event_map = await self.store.get_events(members_before_join.values())
         current_state = {
             state_key: event_map[event_id]
-            for state_key, event_id in state_before_join.items()
+            for state_key, event_id in members_before_join.items()
         }
         allowed_servers = get_servers_from_users(
             get_users_which_can_issue_invite(current_state)
@@ -1120,7 +1149,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         # Ensure the member should be allowed access via membership in a room.
         await self.event_auth_handler.check_restricted_join_rules(
-            state_before_join, room_version, user_id, prev_member_event
+            members_before_join, room_version, user_id, prev_member_event
         )
 
         # If this is going to be a local join, additional information must
@@ -1130,7 +1159,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             EventContentFields.AUTHORISING_USER
         ] = await self.event_auth_handler.get_user_which_could_invite(
             room_id,
-            state_before_join,
+            members_before_join,
         )
 
         return False, []
@@ -1236,7 +1265,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             requester = types.create_requester(target_user)
 
         prev_state_ids = await context.get_prev_state_ids(
-            StateFilter.from_types([(EventTypes.GuestAccess, None)])
+            StateFilter.from_types(
+                [(EventTypes.GuestAccess, None), (EventTypes.Member, event.state_key)]
+            )
         )
         if event.membership == Membership.JOIN:
             if requester.is_guest: