diff --git a/changelog.d/13283.misc b/changelog.d/13283.misc
new file mode 100644
index 0000000000..ea2510ca53
--- /dev/null
+++ b/changelog.d/13283.misc
@@ -0,0 +1 @@
+Don't fetch the full state during membership changes.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index e151962055..4bb4d09d4a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -856,13 +856,20 @@ class FederationHandler:
# Note that this requires the /send_join request to come back to the
# same server.
if room_version.msc3083_join_rules:
- state_ids = await self._state_storage_controller.get_current_state_ids(
- room_id
+ partial_state_ids = (
+ await self._state_storage_controller.get_current_state_ids(
+ room_id,
+ state_filter=StateFilter.from_types(
+ [(EventTypes.JoinRules, ""), (EventTypes.Member, user_id)]
+ ),
+ )
)
if await self._event_auth_handler.has_restricted_join_rules(
- state_ids, room_version
+ partial_state_ids, room_version
):
- prev_member_event_id = state_ids.get((EventTypes.Member, user_id), None)
+ prev_member_event_id = partial_state_ids.get(
+ (EventTypes.Member, user_id), None
+ )
# If the user is invited or joined to the room already, then
# no additional info is needed.
include_auth_user_id = True
@@ -874,6 +881,12 @@ class FederationHandler:
)
if include_auth_user_id:
+ state_ids = (
+ await self._state_storage_controller.get_current_state_ids(
+ room_id,
+ )
+ )
+
event_content[
EventContentFields.AUTHORISING_USER
] = await self._event_auth_handler.get_user_which_could_invite(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 65b9a655d4..9c0fdeca15 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -16,7 +16,7 @@ import abc
import logging
import random
from http import HTTPStatus
-from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set, Tuple
from synapse import types
from synapse.api.constants import (
@@ -410,11 +410,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
historical=historical,
)
- prev_state_ids = await context.get_prev_state_ids(
- StateFilter.from_types([(EventTypes.Member, None)])
+ prev_member_event_ids = await context.get_prev_state_ids(
+ StateFilter.from_types([(EventTypes.Member, user_id)])
)
- prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
+ prev_member_event_id = prev_member_event_ids.get(
+ (EventTypes.Member, user_id), None
+ )
if event.membership == Membership.JOIN:
newly_joined = True
@@ -790,14 +792,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
- state_before_join = await self.state_handler.compute_state_after_events(
- room_id, latest_event_ids
+ old_membership_state_ids = await self.state_handler.compute_state_after_events(
+ room_id,
+ event_ids=latest_event_ids,
+ state_filter=StateFilter.from_types(
+ [(EventTypes.Member, target.to_string())]
+ ),
)
# TODO: Refactor into dictionary of explicitly allowed transitions
# between old and new state, with specific error messages for some
# transitions and generic otherwise
- old_state_id = state_before_join.get((EventTypes.Member, target.to_string()))
+ old_state_id = old_membership_state_ids.get(
+ (EventTypes.Member, target.to_string())
+ )
if old_state_id:
old_state = await self.store.get_event(old_state_id, allow_none=True)
old_membership = old_state.content.get("membership") if old_state else None
@@ -848,11 +856,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if action == "kick":
raise AuthError(403, "The target user is not in the room")
- is_host_in_room = await self._is_host_in_room(state_before_join)
+ is_host_in_room = await self.store.is_host_joined(room_id, self._server_name)
if effective_membership_state == Membership.JOIN:
if requester.is_guest:
- guest_can_join = await self._can_guest_join(state_before_join)
+ guest_access_ids = await self.state_handler.compute_state_after_events(
+ room_id,
+ event_ids=latest_event_ids,
+ state_filter=StateFilter.from_types([(EventTypes.GuestAccess, "")]),
+ )
+ guest_can_join = await self._can_guest_join(guest_access_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
@@ -895,7 +908,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
remote_room_hosts,
content,
is_host_in_room,
- state_before_join,
+ latest_event_ids,
)
if remote_join:
if ratelimit:
@@ -1040,7 +1053,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
remote_room_hosts: List[str],
content: JsonDict,
is_host_in_room: bool,
- state_before_join: StateMap[str],
+ latest_event_ids: Collection[str],
) -> Tuple[bool, List[str]]:
"""
Check whether the server should do a remote join (as opposed to a local
@@ -1060,8 +1073,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
content: The content to use as the event body of the join. This may
be modified.
is_host_in_room: True if the host is in the room.
- state_before_join: The state before the join event (i.e. the resolution of
- the states after its parent events).
+ latest_event_ids: The parent events of the join event.
Returns:
A tuple of:
@@ -1079,16 +1091,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# for restricted join rules, a remote join must be used.
room_version = await self.store.get_room_version(room_id)
+ # Only fetch the state that we need to check if we need to worry about
+ # restricted join rules.
+ partial_state_ids = await self.state_handler.compute_state_after_events(
+ room_id,
+ latest_event_ids,
+ StateFilter.from_types(
+ [(EventTypes.JoinRules, ""), (EventTypes.Member, user_id)]
+ ),
+ )
+
# If restricted join rules are not being used, a local join can always
# be used.
if not await self.event_auth_handler.has_restricted_join_rules(
- state_before_join, room_version
+ partial_state_ids, room_version
):
return False, []
# If the user is invited to the room or already joined, the join
# event can always be issued locally.
- prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None)
+ prev_member_event_id = partial_state_ids.get((EventTypes.Member, user_id), None)
prev_member_event = None
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
@@ -1098,15 +1120,22 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
):
return False, []
+ # Now we need to inspect the full membership, so pull that from the DB.
+ members_before_join = await self.state_handler.compute_state_after_events(
+ room_id,
+ latest_event_ids,
+ StateFilter.from_types([(EventTypes.Member, None)]),
+ )
+
# If the local host has a user who can issue invites, then a local
# join can be done.
#
# If not, generate a new list of remote hosts based on which
# can issue invites.
- event_map = await self.store.get_events(state_before_join.values())
+ event_map = await self.store.get_events(members_before_join.values())
current_state = {
state_key: event_map[event_id]
- for state_key, event_id in state_before_join.items()
+ for state_key, event_id in members_before_join.items()
}
allowed_servers = get_servers_from_users(
get_users_which_can_issue_invite(current_state)
@@ -1120,7 +1149,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Ensure the member should be allowed access via membership in a room.
await self.event_auth_handler.check_restricted_join_rules(
- state_before_join, room_version, user_id, prev_member_event
+ members_before_join, room_version, user_id, prev_member_event
)
# If this is going to be a local join, additional information must
@@ -1130,7 +1159,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
EventContentFields.AUTHORISING_USER
] = await self.event_auth_handler.get_user_which_could_invite(
room_id,
- state_before_join,
+ members_before_join,
)
return False, []
@@ -1236,7 +1265,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
requester = types.create_requester(target_user)
prev_state_ids = await context.get_prev_state_ids(
- StateFilter.from_types([(EventTypes.GuestAccess, None)])
+ StateFilter.from_types(
+ [(EventTypes.GuestAccess, None), (EventTypes.Member, event.state_key)]
+ )
)
if event.membership == Membership.JOIN:
if requester.is_guest:
|