summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/events/builder.py2
-rw-r--r--synapse/handlers/room_member.py35
-rw-r--r--synapse/state/__init__.py21
3 files changed, 34 insertions, 24 deletions
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 98c203ada0..4caf6cbdee 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -120,7 +120,7 @@ class EventBuilder:
             The signed and hashed event.
         """
         if auth_event_ids is None:
-            state_ids = await self._state.get_current_state_ids(
+            state_ids = await self._state.compute_state_after_events(
                 self.room_id, prev_event_ids
             )
             auth_event_ids = self._event_auth_handler.compute_auth_events(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 90e0b21600..a5b9ac904e 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -755,14 +755,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         latest_event_ids = await self.store.get_prev_events_for_room(room_id)
 
-        current_state_ids = await self.state_handler.get_current_state_ids(
-            room_id, latest_event_ids=latest_event_ids
+        state_before_join = await self.state_handler.compute_state_after_events(
+            room_id, latest_event_ids
         )
 
         # TODO: Refactor into dictionary of explicitly allowed transitions
         # between old and new state, with specific error messages for some
         # transitions and generic otherwise
-        old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
+        old_state_id = state_before_join.get((EventTypes.Member, target.to_string()))
         if old_state_id:
             old_state = await self.store.get_event(old_state_id, allow_none=True)
             old_membership = old_state.content.get("membership") if old_state else None
@@ -813,11 +813,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             if action == "kick":
                 raise AuthError(403, "The target user is not in the room")
 
-        is_host_in_room = await self._is_host_in_room(current_state_ids)
+        is_host_in_room = await self._is_host_in_room(state_before_join)
 
         if effective_membership_state == Membership.JOIN:
             if requester.is_guest:
-                guest_can_join = await self._can_guest_join(current_state_ids)
+                guest_can_join = await self._can_guest_join(state_before_join)
                 if not guest_can_join:
                     # This should be an auth check, but guests are a local concept,
                     # so don't really fit into the general auth process.
@@ -855,7 +855,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
             # Check if a remote join should be performed.
             remote_join, remote_room_hosts = await self._should_perform_remote_join(
-                target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
+                target.to_string(),
+                room_id,
+                remote_room_hosts,
+                content,
+                is_host_in_room,
+                state_before_join,
             )
             if remote_join:
                 if ratelimit:
@@ -995,6 +1000,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         remote_room_hosts: List[str],
         content: JsonDict,
         is_host_in_room: bool,
+        state_before_join: StateMap[str],
     ) -> Tuple[bool, List[str]]:
         """
         Check whether the server should do a remote join (as opposed to a local
@@ -1014,6 +1020,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             content: The content to use as the event body of the join. This may
                 be modified.
             is_host_in_room: True if the host is in the room.
+            state_before_join: The state before the join event (i.e. the resolution of
+                the states after its parent events).
 
         Returns:
             A tuple of:
@@ -1030,20 +1038,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # If the host is in the room, but not one of the authorised hosts
         # for restricted join rules, a remote join must be used.
         room_version = await self.store.get_room_version(room_id)
-        current_state_ids = await self._storage_controllers.state.get_current_state_ids(
-            room_id
-        )
 
         # If restricted join rules are not being used, a local join can always
         # be used.
         if not await self.event_auth_handler.has_restricted_join_rules(
-            current_state_ids, room_version
+            state_before_join, room_version
         ):
             return False, []
 
         # If the user is invited to the room or already joined, the join
         # event can always be issued locally.
-        prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None)
+        prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None)
         prev_member_event = None
         if prev_member_event_id:
             prev_member_event = await self.store.get_event(prev_member_event_id)
@@ -1058,10 +1063,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         #
         # If not, generate a new list of remote hosts based on which
         # can issue invites.
-        event_map = await self.store.get_events(current_state_ids.values())
+        event_map = await self.store.get_events(state_before_join.values())
         current_state = {
             state_key: event_map[event_id]
-            for state_key, event_id in current_state_ids.items()
+            for state_key, event_id in state_before_join.items()
         }
         allowed_servers = get_servers_from_users(
             get_users_which_can_issue_invite(current_state)
@@ -1075,7 +1080,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         # Ensure the member should be allowed access via membership in a room.
         await self.event_auth_handler.check_restricted_join_rules(
-            current_state_ids, room_version, user_id, prev_member_event
+            state_before_join, room_version, user_id, prev_member_event
         )
 
         # If this is going to be a local join, additional information must
@@ -1085,7 +1090,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             EventContentFields.AUTHORISING_USER
         ] = await self.event_auth_handler.get_user_which_could_invite(
             room_id,
-            current_state_ids,
+            state_before_join,
         )
 
         return False, []
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 3a65bd0849..56606e9afb 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -153,22 +153,27 @@ class StateHandler:
             ReplicationUpdateCurrentStateRestServlet.make_client(hs)
         )
 
-    async def get_current_state_ids(
+    async def compute_state_after_events(
         self,
         room_id: str,
-        latest_event_ids: Collection[str],
+        event_ids: Collection[str],
     ) -> StateMap[str]:
-        """Get the current state, or the state at a set of events, for a room
+        """Fetch the state after each of the given event IDs. Resolve them and return.
+
+        This is typically used where `event_ids` is a collection of forward extremities
+        in a room, intended to become the `prev_events` of a new event E. If so, the
+        return value of this function represents the state before E.
 
         Args:
-            room_id:
-            latest_event_ids: The forward extremities to resolve.
+            room_id: the room_id containing the given events.
+            event_ids: the events whose state should be fetched and resolved.
 
         Returns:
-            the state dict, mapping from (event_type, state_key) -> event_id
+            the state dict (a mapping from (event_type, state_key) -> event_id) which
+            holds the resolution of the states after the given event IDs.
         """
-        logger.debug("calling resolve_state_groups from get_current_state_ids")
-        ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
+        logger.debug("calling resolve_state_groups from compute_state_after_events")
+        ret = await self.resolve_state_groups_for_events(room_id, event_ids)
         return await ret.get_state(self._state_storage_controller, StateFilter.all())
 
     async def get_current_users_in_room(