diff --git a/changelog.d/10254.feature b/changelog.d/10254.feature
new file mode 100644
index 0000000000..df8bb51167
--- /dev/null
+++ b/changelog.d/10254.feature
@@ -0,0 +1 @@
+Update support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083) to consider changes in the MSC around which servers can issue join events.
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 054ab14ab6..dc662bca83 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -75,6 +75,9 @@ class Codes:
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED"
BAD_ALIAS = "M_BAD_ALIAS"
+ # For restricted join rules.
+ UNABLE_AUTHORISE_JOIN = "M_UNABLE_TO_AUTHORISE_JOIN"
+ UNABLE_TO_GRANT_JOIN = "M_UNABLE_TO_GRANT_JOIN"
class CodeMessageException(RuntimeError):
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 8dd33dcb83..697319e52d 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -168,7 +168,7 @@ class RoomVersions:
msc2403_knocking=False,
)
MSC3083 = RoomVersion(
- "org.matrix.msc3083",
+ "org.matrix.msc3083.v2",
RoomDisposition.UNSTABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 137dff2513..cc92d35477 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -106,6 +106,18 @@ def check(
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
+ is_invite_via_allow_rule = (
+ event.type == EventTypes.Member
+ and event.membership == Membership.JOIN
+ and "join_authorised_via_users_server" in event.content
+ )
+ if is_invite_via_allow_rule:
+ authoriser_domain = get_domain_from_id(
+ event.content["join_authorised_via_users_server"]
+ )
+ if not event.signatures.get(authoriser_domain):
+ raise AuthError(403, "Event not signed by authorising server")
+
# Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules
#
# 1. If type is m.room.create:
@@ -177,7 +189,7 @@ def check(
# https://github.com/vector-im/vector-web/issues/1208 hopefully
if event.type == EventTypes.ThirdPartyInvite:
user_level = get_user_power_level(event.user_id, auth_events)
- invite_level = _get_named_level(auth_events, "invite", 0)
+ invite_level = get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(403, "You don't have permission to invite users")
@@ -285,8 +297,8 @@ def _is_membership_change_allowed(
user_level = get_user_power_level(event.user_id, auth_events)
target_level = get_user_power_level(target_user_id, auth_events)
- # FIXME (erikj): What should we do here as the default?
- ban_level = _get_named_level(auth_events, "ban", 50)
+ invite_level = get_named_level(auth_events, "invite", 0)
+ ban_level = get_named_level(auth_events, "ban", 50)
logger.debug(
"_is_membership_change_allowed: %s",
@@ -336,8 +348,6 @@ def _is_membership_change_allowed(
elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." % target_user_id)
else:
- invite_level = _get_named_level(auth_events, "invite", 0)
-
if user_level < invite_level:
raise AuthError(403, "You don't have permission to invite users")
elif Membership.JOIN == membership:
@@ -345,16 +355,41 @@ def _is_membership_change_allowed(
# * They are not banned.
# * They are accepting a previously sent invitation.
# * They are already joined (it's a NOOP).
- # * The room is public or restricted.
+ # * The room is public.
+ # * The room is restricted and the user meets the allows rules.
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
elif target_banned:
raise AuthError(403, "You are banned from this room")
- elif join_rule == JoinRules.PUBLIC or (
+ elif join_rule == JoinRules.PUBLIC:
+ pass
+ elif (
room_version.msc3083_join_rules
and join_rule == JoinRules.MSC3083_RESTRICTED
):
- pass
+ # This is the same as public, but the event must contain a reference
+ # to the server who authorised the join. If the event does not contain
+ # the proper content it is rejected.
+ #
+ # Note that if the caller is in the room or invited, then they do
+ # not need to meet the allow rules.
+ if not caller_in_room and not caller_invited:
+ authorising_user = event.content.get("join_authorised_via_users_server")
+
+ if authorising_user is None:
+ raise AuthError(403, "Join event is missing authorising user.")
+
+ # The authorising user must be in the room.
+ key = (EventTypes.Member, authorising_user)
+ member_event = auth_events.get(key)
+ _check_joined_room(member_event, authorising_user, event.room_id)
+
+ authorising_user_level = get_user_power_level(
+ authorising_user, auth_events
+ )
+ if authorising_user_level < invite_level:
+ raise AuthError(403, "Join event authorised by invalid server.")
+
elif join_rule == JoinRules.INVITE or (
room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
):
@@ -369,7 +404,7 @@ def _is_membership_change_allowed(
if target_banned and user_level < ban_level:
raise AuthError(403, "You cannot unban user %s." % (target_user_id,))
elif target_user_id != event.user_id:
- kick_level = _get_named_level(auth_events, "kick", 50)
+ kick_level = get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level:
raise AuthError(403, "You cannot kick user %s." % target_user_id)
@@ -445,7 +480,7 @@ def get_send_level(
def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
- power_levels_event = _get_power_level_event(auth_events)
+ power_levels_event = get_power_level_event(auth_events)
send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
user_level = get_user_power_level(event.user_id, auth_events)
@@ -485,7 +520,7 @@ def check_redaction(
"""
user_level = get_user_power_level(event.user_id, auth_events)
- redact_level = _get_named_level(auth_events, "redact", 50)
+ redact_level = get_named_level(auth_events, "redact", 50)
if user_level >= redact_level:
return False
@@ -600,7 +635,7 @@ def _check_power_levels(
)
-def _get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
+def get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
return auth_events.get((EventTypes.PowerLevels, ""))
@@ -616,7 +651,7 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
Returns:
the user's power level in this room.
"""
- power_level_event = _get_power_level_event(auth_events)
+ power_level_event = get_power_level_event(auth_events)
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
if not level:
@@ -640,8 +675,8 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
return 0
-def _get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
- power_level_event = _get_power_level_event(auth_events)
+def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
+ power_level_event = get_power_level_event(auth_events)
if not power_level_event:
return default
@@ -728,7 +763,9 @@ def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
return public_keys
-def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str, str]]:
+def auth_types_for_event(
+ room_version: RoomVersion, event: Union[EventBase, EventBuilder]
+) -> Set[Tuple[str, str]]:
"""Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what
would actually be required depending on the full state of the room.
@@ -760,4 +797,12 @@ def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str
)
auth_types.add(key)
+ if room_version.msc3083_join_rules and membership == Membership.JOIN:
+ if "join_authorised_via_users_server" in event.content:
+ key = (
+ EventTypes.Member,
+ event.content["join_authorised_via_users_server"],
+ )
+ auth_types.add(key)
+
return auth_types
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 2bfe6a3d37..024e440ff4 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -178,6 +178,34 @@ async def _check_sigs_on_pdu(
)
raise SynapseError(403, errmsg, Codes.FORBIDDEN)
+ # If this is a join event for a restricted room it may have been authorised
+ # via a different server from the sending server. Check those signatures.
+ if (
+ room_version.msc3083_join_rules
+ and pdu.type == EventTypes.Member
+ and pdu.membership == Membership.JOIN
+ and "join_authorised_via_users_server" in pdu.content
+ ):
+ authorising_server = get_domain_from_id(
+ pdu.content["join_authorised_via_users_server"]
+ )
+ try:
+ await keyring.verify_event_for_server(
+ authorising_server,
+ pdu,
+ pdu.origin_server_ts if room_version.enforce_key_validity else 0,
+ )
+ except Exception as e:
+ errmsg = (
+ "event id %s: unable to verify signature for authorising server %s: %s"
+ % (
+ pdu.event_id,
+ authorising_server,
+ e,
+ )
+ )
+ raise SynapseError(403, errmsg, Codes.FORBIDDEN)
+
def _is_invite_via_3pid(event: EventBase) -> bool:
return (
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c767d30627..dbadf102f2 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -19,7 +19,6 @@ import itertools
import logging
from typing import (
TYPE_CHECKING,
- Any,
Awaitable,
Callable,
Collection,
@@ -79,7 +78,15 @@ class InvalidResponseError(RuntimeError):
we couldn't parse
"""
- pass
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class SendJoinResult:
+ # The event to persist.
+ event: EventBase
+ # A string giving the server the event was sent to.
+ origin: str
+ state: List[EventBase]
+ auth_chain: List[EventBase]
class FederationClient(FederationBase):
@@ -677,7 +684,7 @@ class FederationClient(FederationBase):
async def send_join(
self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion
- ) -> Dict[str, Any]:
+ ) -> SendJoinResult:
"""Sends a join event to one of a list of homeservers.
Doing so will cause the remote server to add the event to the graph,
@@ -691,18 +698,38 @@ class FederationClient(FederationBase):
did the make_join)
Returns:
- a dict with members ``origin`` (a string
- giving the server the event was sent to, ``state`` (?) and
- ``auth_chain``.
+ The result of the send join request.
Raises:
SynapseError: if the chosen remote server returns a 300/400 code, or
no servers successfully handle the request.
"""
- async def send_request(destination) -> Dict[str, Any]:
+ async def send_request(destination) -> SendJoinResult:
response = await self._do_send_join(room_version, destination, pdu)
+ # If an event was returned (and expected to be returned):
+ #
+ # * Ensure it has the same event ID (note that the event ID is a hash
+ # of the event fields for versions which support MSC3083).
+ # * Ensure the signatures are good.
+ #
+ # Otherwise, fallback to the provided event.
+ if room_version.msc3083_join_rules and response.event:
+ event = response.event
+
+ valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
+ pdu=event,
+ origin=destination,
+ outlier=True,
+ room_version=room_version,
+ )
+
+ if valid_pdu is None or event.event_id != pdu.event_id:
+ raise InvalidResponseError("Returned an invalid join event")
+ else:
+ event = pdu
+
state = response.state
auth_chain = response.auth_events
@@ -784,11 +811,21 @@ class FederationClient(FederationBase):
% (auth_chain_create_events,)
)
- return {
- "state": signed_state,
- "auth_chain": signed_auth,
- "origin": destination,
- }
+ return SendJoinResult(
+ event=event,
+ state=signed_state,
+ auth_chain=signed_auth,
+ origin=destination,
+ )
+
+ if room_version.msc3083_join_rules:
+ # If the join is being authorised via allow rules, we need to send
+ # the /send_join back to the same server that was originally used
+ # with /make_join.
+ if "join_authorised_via_users_server" in pdu.content:
+ destinations = [
+ get_domain_from_id(pdu.content["join_authorised_via_users_server"])
+ ]
return await self._try_destination_list("send_join", destinations, send_request)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 29619aeeb8..2892a11d7d 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -45,6 +45,7 @@ from synapse.api.errors import (
UnsupportedRoomVersionError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.crypto.event_signing import compute_event_signature
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
@@ -64,7 +65,7 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet,
)
from synapse.storage.databases.main.lock import Lock
-from synapse.types import JsonDict
+from synapse.types import JsonDict, get_domain_from_id
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@@ -586,7 +587,7 @@ class FederationServer(FederationBase):
async def on_send_join_request(
self, origin: str, content: JsonDict, room_id: str
) -> Dict[str, Any]:
- context = await self._on_send_membership_event(
+ event, context = await self._on_send_membership_event(
origin, content, Membership.JOIN, room_id
)
@@ -597,6 +598,7 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
return {
+ "org.matrix.msc3083.v2.event": event.get_pdu_json(),
"state": [p.get_pdu_json(time_now) for p in state.values()],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
}
@@ -681,7 +683,7 @@ class FederationServer(FederationBase):
Returns:
The stripped room state.
"""
- event_context = await self._on_send_membership_event(
+ _, context = await self._on_send_membership_event(
origin, content, Membership.KNOCK, room_id
)
@@ -690,14 +692,14 @@ class FederationServer(FederationBase):
# related to the room while the knock request is pending.
stripped_room_state = (
await self.store.get_stripped_room_state_from_event_context(
- event_context, self._room_prejoin_state_types
+ context, self._room_prejoin_state_types
)
)
return {"knock_state_events": stripped_room_state}
async def _on_send_membership_event(
self, origin: str, content: JsonDict, membership_type: str, room_id: str
- ) -> EventContext:
+ ) -> Tuple[EventBase, EventContext]:
"""Handle an on_send_{join,leave,knock} request
Does some preliminary validation before passing the request on to the
@@ -712,7 +714,7 @@ class FederationServer(FederationBase):
in the event
Returns:
- The context of the event after inserting it into the room graph.
+ The event and context of the event after inserting it into the room graph.
Raises:
SynapseError if there is a problem with the request, including things like
@@ -748,6 +750,33 @@ class FederationServer(FederationBase):
logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures)
+ # Sign the event since we're vouching on behalf of the remote server that
+ # the event is valid to be sent into the room. Currently this is only done
+ # if the user is being joined via restricted join rules.
+ if (
+ room_version.msc3083_join_rules
+ and event.membership == Membership.JOIN
+ and "join_authorised_via_users_server" in event.content
+ ):
+ # We can only authorise our own users.
+ authorising_server = get_domain_from_id(
+ event.content["join_authorised_via_users_server"]
+ )
+ if authorising_server != self.server_name:
+ raise SynapseError(
+ 400,
+ f"Cannot authorise request from resident server: {authorising_server}",
+ )
+
+ event.signatures.update(
+ compute_event_signature(
+ room_version,
+ event.get_pdu_json(),
+ self.hs.hostname,
+ self.hs.signing_key,
+ )
+ )
+
event = await self._check_sigs_and_hash(room_version, event)
return await self.handler.on_send_membership_event(origin, event)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index e73bdb52b3..6a8d3ad4fe 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1219,8 +1219,26 @@ def _create_v2_path(path: str, *args: str) -> str:
class SendJoinResponse:
"""The parsed response of a `/send_join` request."""
+ # The list of auth events from the /send_join response.
auth_events: List[EventBase]
+ # The list of state from the /send_join response.
state: List[EventBase]
+ # The raw join event from the /send_join response.
+ event_dict: JsonDict
+ # The parsed join event from the /send_join response. This will be None if
+ # "event" is not included in the response.
+ event: Optional[EventBase] = None
+
+
+@ijson.coroutine
+def _event_parser(event_dict: JsonDict):
+ """Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
+ to add them to a given dictionary.
+ """
+
+ while True:
+ key, value = yield
+ event_dict[key] = value
@ijson.coroutine
@@ -1246,7 +1264,8 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
CONTENT_TYPE = "application/json"
def __init__(self, room_version: RoomVersion, v1_api: bool):
- self._response = SendJoinResponse([], [])
+ self._response = SendJoinResponse([], [], {})
+ self._room_version = room_version
# The V1 API has the shape of `[200, {...}]`, which we handle by
# prefixing with `item.*`.
@@ -1260,12 +1279,21 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
_event_list_parser(room_version, self._response.auth_events),
prefix + "auth_chain.item",
)
+ self._coro_event = ijson.kvitems_coro(
+ _event_parser(self._response.event_dict),
+ prefix + "org.matrix.msc3083.v2.event",
+ )
def write(self, data: bytes) -> int:
self._coro_state.send(data)
self._coro_auth.send(data)
+ self._coro_event.send(data)
return len(data)
def finish(self) -> SendJoinResponse:
+ if self._response.event_dict:
+ self._response.event = make_event_from_dict(
+ self._response.event_dict, self._room_version
+ )
return self._response
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 41dbdfd0a1..53fac1f8a3 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
from typing import TYPE_CHECKING, Collection, List, Optional, Union
from synapse import event_auth
@@ -20,16 +21,18 @@ from synapse.api.constants import (
Membership,
RestrictedJoinRuleTypes,
)
-from synapse.api.errors import AuthError
+from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
-from synapse.types import StateMap
+from synapse.types import StateMap, get_domain_from_id
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
+logger = logging.getLogger(__name__)
+
class EventAuthHandler:
"""
@@ -39,6 +42,7 @@ class EventAuthHandler:
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._store = hs.get_datastore()
+ self._server_name = hs.hostname
async def check_from_context(
self, room_version: str, event, context, do_sig_check=True
@@ -81,15 +85,76 @@ class EventAuthHandler:
# introduce undesirable "state reset" behaviour.
#
# All of which sounds a bit tricky so we don't bother for now.
-
auth_ids = []
- for etype, state_key in event_auth.auth_types_for_event(event):
+ for etype, state_key in event_auth.auth_types_for_event(
+ event.room_version, event
+ ):
auth_ev_id = current_state_ids.get((etype, state_key))
if auth_ev_id:
auth_ids.append(auth_ev_id)
return auth_ids
+ async def get_user_which_could_invite(
+ self, room_id: str, current_state_ids: StateMap[str]
+ ) -> str:
+ """
+ Searches the room state for a local user who has the power level necessary
+ to invite other users.
+
+ Args:
+ room_id: The room ID under search.
+ current_state_ids: The current state of the room.
+
+ Returns:
+ The MXID of the user which could issue an invite.
+
+ Raises:
+ SynapseError if no appropriate user is found.
+ """
+ power_level_event_id = current_state_ids.get((EventTypes.PowerLevels, ""))
+ invite_level = 0
+ users_default_level = 0
+ if power_level_event_id:
+ power_level_event = await self._store.get_event(power_level_event_id)
+ invite_level = power_level_event.content.get("invite", invite_level)
+ users_default_level = power_level_event.content.get(
+ "users_default", users_default_level
+ )
+ users = power_level_event.content.get("users", {})
+ else:
+ users = {}
+
+ # Find the user with the highest power level.
+ users_in_room = await self._store.get_users_in_room(room_id)
+ # Only interested in local users.
+ local_users_in_room = [
+ u for u in users_in_room if get_domain_from_id(u) == self._server_name
+ ]
+ chosen_user = max(
+ local_users_in_room,
+ key=lambda user: users.get(user, users_default_level),
+ default=None,
+ )
+
+ # Return the chosen if they can issue invites.
+ user_power_level = users.get(chosen_user, users_default_level)
+ if chosen_user and user_power_level >= invite_level:
+ logger.debug(
+ "Found a user who can issue invites %s with power level %d >= invite level %d",
+ chosen_user,
+ user_power_level,
+ invite_level,
+ )
+ return chosen_user
+
+ # No user was found.
+ raise SynapseError(
+ 400,
+ "Unable to find a user which could issue an invite",
+ Codes.UNABLE_TO_GRANT_JOIN,
+ )
+
async def check_host_in_room(self, room_id: str, host: str) -> bool:
with Measure(self._clock, "check_host_in_room"):
return await self._store.is_host_joined(room_id, host)
@@ -134,6 +199,18 @@ class EventAuthHandler:
# in any of them.
allowed_rooms = await self.get_rooms_that_allow_join(state_ids)
if not await self.is_user_in_rooms(allowed_rooms, user_id):
+
+ # If this is a remote request, the user might be in an allowed room
+ # that we do not know about.
+ if get_domain_from_id(user_id) != self._server_name:
+ for room_id in allowed_rooms:
+ if not await self._store.is_host_joined(room_id, self._server_name):
+ raise SynapseError(
+ 400,
+ f"Unable to check if {user_id} is in allowed rooms.",
+ Codes.UNABLE_AUTHORISE_JOIN,
+ )
+
raise AuthError(
403,
"You do not belong to any of the required rooms to join this room.",
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 5728719909..aba095d2e1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1494,9 +1494,10 @@ class FederationHandler(BaseHandler):
host_list, event, room_version_obj
)
- origin = ret["origin"]
- state = ret["state"]
- auth_chain = ret["auth_chain"]
+ event = ret.event
+ origin = ret.origin
+ state = ret.state
+ auth_chain = ret.auth_chain
auth_chain.sort(key=lambda e: e.depth)
logger.debug("do_invite_join auth_chain: %s", auth_chain)
@@ -1676,7 +1677,7 @@ class FederationHandler(BaseHandler):
# checking the room version will check that we've actually heard of the room
# (and return a 404 otherwise)
- room_version = await self.store.get_room_version_id(room_id)
+ room_version = await self.store.get_room_version(room_id)
# now check that we are *still* in the room
is_in_room = await self._event_auth_handler.check_host_in_room(
@@ -1691,8 +1692,38 @@ class FederationHandler(BaseHandler):
event_content = {"membership": Membership.JOIN}
+ # If the current room is using restricted join rules, additional information
+ # may need to be included in the event content in order to efficiently
+ # validate the event.
+ #
+ # Note that this requires the /send_join request to come back to the
+ # same server.
+ if room_version.msc3083_join_rules:
+ state_ids = await self.store.get_current_state_ids(room_id)
+ if await self._event_auth_handler.has_restricted_join_rules(
+ state_ids, room_version
+ ):
+ prev_member_event_id = state_ids.get((EventTypes.Member, user_id), None)
+ # If the user is invited or joined to the room already, then
+ # no additional info is needed.
+ include_auth_user_id = True
+ if prev_member_event_id:
+ prev_member_event = await self.store.get_event(prev_member_event_id)
+ include_auth_user_id = prev_member_event.membership not in (
+ Membership.JOIN,
+ Membership.INVITE,
+ )
+
+ if include_auth_user_id:
+ event_content[
+ "join_authorised_via_users_server"
+ ] = await self._event_auth_handler.get_user_which_could_invite(
+ room_id,
+ state_ids,
+ )
+
builder = self.event_builder_factory.new(
- room_version,
+ room_version.identifier,
{
"type": EventTypes.Member,
"content": event_content,
@@ -1710,10 +1741,13 @@ class FederationHandler(BaseHandler):
logger.warning("Failed to create join to %s because %s", room_id, e)
raise
+ # Ensure the user can even join the room.
+ await self._check_join_restrictions(context, event)
+
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
await self._event_auth_handler.check_from_context(
- room_version, event, context, do_sig_check=False
+ room_version.identifier, event, context, do_sig_check=False
)
return event
@@ -1958,7 +1992,7 @@ class FederationHandler(BaseHandler):
@log_function
async def on_send_membership_event(
self, origin: str, event: EventBase
- ) -> EventContext:
+ ) -> Tuple[EventBase, EventContext]:
"""
We have received a join/leave/knock event for a room via send_join/leave/knock.
@@ -1981,7 +2015,7 @@ class FederationHandler(BaseHandler):
event: The member event that has been signed by the remote homeserver.
Returns:
- The context of the event after inserting it into the room graph.
+ The event and context of the event after inserting it into the room graph.
Raises:
SynapseError if the event is not accepted into the room
@@ -2037,7 +2071,7 @@ class FederationHandler(BaseHandler):
# all looks good, we can persist the event.
await self._run_push_actions_and_persist_event(event, context)
- return context
+ return event, context
async def _check_join_restrictions(
self, context: EventContext, event: EventBase
@@ -2473,7 +2507,7 @@ class FederationHandler(BaseHandler):
)
# Now check if event pass auth against said current state
- auth_types = auth_types_for_event(event)
+ auth_types = auth_types_for_event(room_version_obj, event)
current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 1192591609..65ad3efa6a 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -16,7 +16,7 @@ import abc
import logging
import random
from http import HTTPStatus
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple
from synapse import types
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
@@ -28,6 +28,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
+from synapse.event_auth import get_named_level, get_power_level_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import (
@@ -340,16 +341,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if event.membership == Membership.JOIN:
newly_joined = True
- prev_member_event = None
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
- # Check if the member should be allowed access via membership in a space.
- await self.event_auth_handler.check_restricted_join_rules(
- prev_state_ids, event.room_version, user_id, prev_member_event
- )
-
# Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates.
if newly_joined and ratelimit:
@@ -701,7 +696,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
- if not is_host_in_room:
+ # Check if a remote join should be performed.
+ remote_join, remote_room_hosts = await self._should_perform_remote_join(
+ target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
+ )
+ if remote_join:
if ratelimit:
time_now_s = self.clock.time()
(
@@ -826,6 +825,106 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
outlier=outlier,
)
+ async def _should_perform_remote_join(
+ self,
+ user_id: str,
+ room_id: str,
+ remote_room_hosts: List[str],
+ content: JsonDict,
+ is_host_in_room: bool,
+ ) -> Tuple[bool, List[str]]:
+ """
+ Check whether the server should do a remote join (as opposed to a local
+ join) for a user.
+
+ Generally a remote join is used if:
+
+ * The server is not yet in the room.
+ * The server is in the room, the room has restricted join rules, the user
+ is not joined or invited to the room, and the server does not have
+ another user who is capable of issuing invites.
+
+ Args:
+ user_id: The user joining the room.
+ room_id: The room being joined.
+ remote_room_hosts: A list of remote room hosts.
+ content: The content to use as the event body of the join. This may
+ be modified.
+ is_host_in_room: True if the host is in the room.
+
+ Returns:
+ A tuple of:
+ True if a remote join should be performed. False if the join can be
+ done locally.
+
+ A list of remote room hosts to use. This is an empty list if a
+ local join is to be done.
+ """
+ # If the host isn't in the room, pass through the prospective hosts.
+ if not is_host_in_room:
+ return True, remote_room_hosts
+
+ # If the host is in the room, but not one of the authorised hosts
+ # for restricted join rules, a remote join must be used.
+ room_version = await self.store.get_room_version(room_id)
+ current_state_ids = await self.store.get_current_state_ids(room_id)
+
+ # If restricted join rules are not being used, a local join can always
+ # be used.
+ if not await self.event_auth_handler.has_restricted_join_rules(
+ current_state_ids, room_version
+ ):
+ return False, []
+
+ # If the user is invited to the room or already joined, the join
+ # event can always be issued locally.
+ prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None)
+ prev_member_event = None
+ if prev_member_event_id:
+ prev_member_event = await self.store.get_event(prev_member_event_id)
+ if prev_member_event.membership in (
+ Membership.JOIN,
+ Membership.INVITE,
+ ):
+ return False, []
+
+ # If the local host has a user who can issue invites, then a local
+ # join can be done.
+ #
+ # If not, generate a new list of remote hosts based on which
+ # can issue invites.
+ event_map = await self.store.get_events(current_state_ids.values())
+ current_state = {
+ state_key: event_map[event_id]
+ for state_key, event_id in current_state_ids.items()
+ }
+ allowed_servers = get_servers_from_users(
+ get_users_which_can_issue_invite(current_state)
+ )
+
+ # If the local server is not one of allowed servers, then a remote
+ # join must be done. Return the list of prospective servers based on
+ # which can issue invites.
+ if self.hs.hostname not in allowed_servers:
+ return True, list(allowed_servers)
+
+ # Ensure the member should be allowed access via membership in a room.
+ await self.event_auth_handler.check_restricted_join_rules(
+ current_state_ids, room_version, user_id, prev_member_event
+ )
+
+ # If this is going to be a local join, additional information must
+ # be included in the event content in order to efficiently validate
+ # the event.
+ content[
+ "join_authorised_via_users_server"
+ ] = await self.event_auth_handler.get_user_which_could_invite(
+ room_id,
+ current_state_ids,
+ )
+
+ return False, []
+
async def transfer_room_state_on_room_upgrade(
self, old_room_id: str, room_id: str
) -> None:
@@ -1514,3 +1613,63 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if membership:
await self.store.forget(user_id, room_id)
+
+
+def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[str]:
+ """
+ Return the list of users which can issue invites.
+
+ This is done by exploring the joined users and comparing their power levels
+ to the necessyar power level to issue an invite.
+
+ Args:
+ auth_events: state in force at this point in the room
+
+ Returns:
+ The users which can issue invites.
+ """
+ invite_level = get_named_level(auth_events, "invite", 0)
+ users_default_level = get_named_level(auth_events, "users_default", 0)
+ power_level_event = get_power_level_event(auth_events)
+
+ # Custom power-levels for users.
+ if power_level_event:
+ users = power_level_event.content.get("users", {})
+ else:
+ users = {}
+
+ result = []
+
+ # Check which members are able to invite by ensuring they're joined and have
+ # the necessary power level.
+ for (event_type, state_key), event in auth_events.items():
+ if event_type != EventTypes.Member:
+ continue
+
+ if event.membership != Membership.JOIN:
+ continue
+
+ # Check if the user has a custom power level.
+ if users.get(state_key, users_default_level) >= invite_level:
+ result.append(state_key)
+
+ return result
+
+
+def get_servers_from_users(users: List[str]) -> Set[str]:
+ """
+ Resolve a list of users into their servers.
+
+ Args:
+ users: A list of users.
+
+ Returns:
+ A set of servers.
+ """
+ servers = set()
+ for user in users:
+ try:
+ servers.add(get_domain_from_id(user))
+ except SynapseError:
+ pass
+ return servers
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 6223daf522..2e15471435 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -636,16 +636,20 @@ class StateResolutionHandler:
"""
try:
with Measure(self.clock, "state._resolve_events") as m:
- v = KNOWN_ROOM_VERSIONS[room_version]
- if v.state_res == StateResolutionVersions.V1:
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+ if room_version_obj.state_res == StateResolutionVersions.V1:
return await v1.resolve_events_with_store(
- room_id, state_sets, event_map, state_res_store.get_events
+ room_id,
+ room_version_obj,
+ state_sets,
+ event_map,
+ state_res_store.get_events,
)
else:
return await v2.resolve_events_with_store(
self.clock,
room_id,
- room_version,
+ room_version_obj,
state_sets,
event_map,
state_res_store,
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 267193cedf..92336d7cc8 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -29,7 +29,7 @@ from typing import (
from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
-from synapse.api.room_versions import RoomVersions
+from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
@@ -41,6 +41,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
async def resolve_events_with_store(
room_id: str,
+ room_version: RoomVersion,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
@@ -104,7 +105,7 @@ async def resolve_events_with_store(
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
auth_events = _create_auth_events_from_maps(
- unconflicted_state, conflicted_state, state_map
+ room_version, unconflicted_state, conflicted_state, state_map
)
new_needed_events = set(auth_events.values())
@@ -132,7 +133,7 @@ async def resolve_events_with_store(
state_map.update(state_map_new)
return _resolve_with_state(
- unconflicted_state, conflicted_state, auth_events, state_map
+ room_version, unconflicted_state, conflicted_state, auth_events, state_map
)
@@ -187,6 +188,7 @@ def _seperate(
def _create_auth_events_from_maps(
+ room_version: RoomVersion,
unconflicted_state: StateMap[str],
conflicted_state: StateMap[Set[str]],
state_map: Dict[str, EventBase],
@@ -194,6 +196,7 @@ def _create_auth_events_from_maps(
"""
Args:
+ room_version: The room version.
unconflicted_state: The unconflicted state map.
conflicted_state: The conflicted state map.
state_map:
@@ -205,7 +208,9 @@ def _create_auth_events_from_maps(
for event_ids in conflicted_state.values():
for event_id in event_ids:
if event_id in state_map:
- keys = event_auth.auth_types_for_event(state_map[event_id])
+ keys = event_auth.auth_types_for_event(
+ room_version, state_map[event_id]
+ )
for key in keys:
if key not in auth_events:
auth_event_id = unconflicted_state.get(key, None)
@@ -215,6 +220,7 @@ def _create_auth_events_from_maps(
def _resolve_with_state(
+ room_version: RoomVersion,
unconflicted_state_ids: MutableStateMap[str],
conflicted_state_ids: StateMap[Set[str]],
auth_event_ids: StateMap[str],
@@ -235,7 +241,9 @@ def _resolve_with_state(
}
try:
- resolved_state = _resolve_state_events(conflicted_state, auth_events)
+ resolved_state = _resolve_state_events(
+ room_version, conflicted_state, auth_events
+ )
except Exception:
logger.exception("Failed to resolve state")
raise
@@ -248,7 +256,9 @@ def _resolve_with_state(
def _resolve_state_events(
- conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
+ room_version: RoomVersion,
+ conflicted_state: StateMap[List[EventBase]],
+ auth_events: MutableStateMap[EventBase],
) -> StateMap[EventBase]:
"""This is where we actually decide which of the conflicted state to
use.
@@ -263,21 +273,27 @@ def _resolve_state_events(
if POWER_KEY in conflicted_state:
events = conflicted_state[POWER_KEY]
logger.debug("Resolving conflicted power levels %r", events)
- resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events)
+ resolved_state[POWER_KEY] = _resolve_auth_events(
+ room_version, events, auth_events
+ )
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
- resolved_state[key] = _resolve_auth_events(events, auth_events)
+ resolved_state[key] = _resolve_auth_events(
+ room_version, events, auth_events
+ )
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
- resolved_state[key] = _resolve_auth_events(events, auth_events)
+ resolved_state[key] = _resolve_auth_events(
+ room_version, events, auth_events
+ )
auth_events.update(resolved_state)
@@ -290,12 +306,14 @@ def _resolve_state_events(
def _resolve_auth_events(
- events: List[EventBase], auth_events: StateMap[EventBase]
+ room_version: RoomVersion, events: List[EventBase], auth_events: StateMap[EventBase]
) -> EventBase:
reverse = list(reversed(_ordered_events(events)))
auth_keys = {
- key for event in events for key in event_auth.auth_types_for_event(event)
+ key
+ for event in events
+ for key in event_auth.auth_types_for_event(room_version, event)
}
new_auth_events = {}
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index e66e6571c8..7b1e8361de 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -36,7 +36,7 @@ import synapse.state
from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
from synapse.util import Clock
@@ -53,7 +53,7 @@ _AWAIT_AFTER_ITERATIONS = 100
async def resolve_events_with_store(
clock: Clock,
room_id: str,
- room_version: str,
+ room_version: RoomVersion,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore",
@@ -497,7 +497,7 @@ async def _reverse_topological_power_sort(
async def _iterative_auth_checks(
clock: Clock,
room_id: str,
- room_version: str,
+ room_version: RoomVersion,
event_ids: List[str],
base_state: StateMap[str],
event_map: Dict[str, EventBase],
@@ -519,7 +519,6 @@ async def _iterative_auth_checks(
Returns the final updated state
"""
resolved_state = dict(base_state)
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
for idx, event_id in enumerate(event_ids, start=1):
event = event_map[event_id]
@@ -538,7 +537,7 @@ async def _iterative_auth_checks(
if ev.rejected_reason is None:
auth_events[(ev.type, ev.state_key)] = ev
- for key in event_auth.auth_types_for_event(event):
+ for key in event_auth.auth_types_for_event(room_version, event):
if key in resolved_state:
ev_id = resolved_state[key]
ev = await _get_event(room_id, ev_id, event_map, state_res_store)
@@ -548,7 +547,7 @@ async def _iterative_auth_checks(
try:
event_auth.check(
- room_version_obj,
+ room_version,
event,
auth_events,
do_sig_check=False,
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 43fc79ca74..8370a27195 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -484,7 +484,7 @@ class StateTestCase(unittest.TestCase):
state_d = resolve_events_with_store(
FakeClock(),
ROOM_ID,
- RoomVersions.V2.identifier,
+ RoomVersions.V2,
[state_at_event[n] for n in prev_events],
event_map=event_map,
state_res_store=TestStateResolutionStore(event_map),
@@ -496,7 +496,7 @@ class StateTestCase(unittest.TestCase):
if fake_event.state_key is not None:
state_after[(fake_event.type, fake_event.state_key)] = event_id
- auth_types = set(auth_types_for_event(fake_event))
+ auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event))
auth_events = []
for key in auth_types:
@@ -633,7 +633,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
state_d = resolve_events_with_store(
FakeClock(),
ROOM_ID,
- RoomVersions.V2.identifier,
+ RoomVersions.V2,
[self.state_at_bob, self.state_at_charlie],
event_map=None,
state_res_store=TestStateResolutionStore(self.event_map),
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index dbacce4380..8c95a0a2fb 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional
+from typing import List, Optional
from canonicaljson import json
@@ -234,8 +234,8 @@ class RedactionTestCase(unittest.HomeserverTestCase):
async def build(
self,
- prev_event_ids,
- auth_event_ids,
+ prev_event_ids: List[str],
+ auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
):
built_event = await self._base_builder.build(
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index f73306ecc4..e5550aec4d 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -351,7 +351,11 @@ class EventAuthTestCase(unittest.TestCase):
"""
Test joining a restricted room from MSC3083.
- This is pretty much the same test as public.
+ This is similar to the public test, but has some additional checks on
+ signatures.
+
+ The checks which care about signatures fake them by simply adding an
+ object of the proper form, not generating valid signatures.
"""
creator = "@creator:example.com"
pleb = "@joiner:example.com"
@@ -359,6 +363,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events = {
("m.room.create", ""): _create_event(creator),
("m.room.member", creator): _join_event(creator),
+ ("m.room.power_levels", ""): _power_levels_event(creator, {"invite": 0}),
("m.room.join_rules", ""): _join_rules_event(creator, "restricted"),
}
@@ -371,19 +376,81 @@ class EventAuthTestCase(unittest.TestCase):
do_sig_check=False,
)
- # Check join.
+ # A properly formatted join event should work.
+ authorised_join_event = _join_event(
+ pleb,
+ additional_content={
+ "join_authorised_via_users_server": "@creator:example.com"
+ },
+ )
event_auth.check(
RoomVersions.MSC3083,
- _join_event(pleb),
+ authorised_join_event,
auth_events,
do_sig_check=False,
)
- # A user cannot be force-joined to a room.
+ # A join issued by a specific user works (i.e. the power level checks
+ # are done properly).
+ pl_auth_events = auth_events.copy()
+ pl_auth_events[("m.room.power_levels", "")] = _power_levels_event(
+ creator, {"invite": 100, "users": {"@inviter:foo.test": 150}}
+ )
+ pl_auth_events[("m.room.member", "@inviter:foo.test")] = _join_event(
+ "@inviter:foo.test"
+ )
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _join_event(
+ pleb,
+ additional_content={
+ "join_authorised_via_users_server": "@inviter:foo.test"
+ },
+ ),
+ pl_auth_events,
+ do_sig_check=False,
+ )
+
+ # A join which is missing an authorised server is rejected.
with self.assertRaises(AuthError):
event_auth.check(
RoomVersions.MSC3083,
- _member_event(pleb, "join", sender=creator),
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # An join authorised by a user who is not in the room is rejected.
+ pl_auth_events = auth_events.copy()
+ pl_auth_events[("m.room.power_levels", "")] = _power_levels_event(
+ creator, {"invite": 100, "users": {"@other:example.com": 150}}
+ )
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _join_event(
+ pleb,
+ additional_content={
+ "join_authorised_via_users_server": "@other:example.com"
+ },
+ ),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user cannot be force-joined to a room. (This uses an event which
+ # *would* be valid, but is sent be a different user.)
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _member_event(
+ pleb,
+ "join",
+ sender=creator,
+ additional_content={
+ "join_authorised_via_users_server": "@inviter:foo.test"
+ },
+ ),
auth_events,
do_sig_check=False,
)
@@ -393,7 +460,7 @@ class EventAuthTestCase(unittest.TestCase):
with self.assertRaises(AuthError):
event_auth.check(
RoomVersions.MSC3083,
- _join_event(pleb),
+ authorised_join_event,
auth_events,
do_sig_check=False,
)
@@ -402,12 +469,13 @@ class EventAuthTestCase(unittest.TestCase):
auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
event_auth.check(
RoomVersions.MSC3083,
- _join_event(pleb),
+ authorised_join_event,
auth_events,
do_sig_check=False,
)
- # A user can send a join if they're in the room.
+ # A user can send a join if they're in the room. (This doesn't need to
+ # be authorised since the user is already joined.)
auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
event_auth.check(
RoomVersions.MSC3083,
@@ -416,7 +484,8 @@ class EventAuthTestCase(unittest.TestCase):
do_sig_check=False,
)
- # A user can accept an invite.
+ # A user can accept an invite. (This doesn't need to be authorised since
+ # the user was invited.)
auth_events[("m.room.member", pleb)] = _member_event(
pleb, "invite", sender=creator
)
@@ -446,7 +515,10 @@ def _create_event(user_id: str) -> EventBase:
def _member_event(
- user_id: str, membership: str, sender: Optional[str] = None
+ user_id: str,
+ membership: str,
+ sender: Optional[str] = None,
+ additional_content: Optional[dict] = None,
) -> EventBase:
return make_event_from_dict(
{
@@ -455,14 +527,14 @@ def _member_event(
"type": "m.room.member",
"sender": sender or user_id,
"state_key": user_id,
- "content": {"membership": membership},
+ "content": {"membership": membership, **(additional_content or {})},
"prev_events": [],
}
)
-def _join_event(user_id: str) -> EventBase:
- return _member_event(user_id, "join")
+def _join_event(user_id: str, additional_content: Optional[dict] = None) -> EventBase:
+ return _member_event(user_id, "join", additional_content=additional_content)
def _power_levels_event(sender: str, content: JsonDict) -> EventBase:
|