summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-07-26 12:17:00 -0400
committerGitHub <noreply@github.com>2021-07-26 12:17:00 -0400
commit228decfce1a71651d64c359d1cf28e10d0a69fc8 (patch)
treee061e3a8c8cc49b0fefefb82ad586ef7229a3dbb /synapse
parentAdd type hints to synapse.federation.transport.client. (#10408) (diff)
downloadsynapse-228decfce1a71651d64c359d1cf28e10d0a69fc8.tar.xz
Update the MSC3083 support to verify if joins are from an authorized server. (#10254)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/errors.py3
-rw-r--r--synapse/api/room_versions.py2
-rw-r--r--synapse/event_auth.py77
-rw-r--r--synapse/federation/federation_base.py28
-rw-r--r--synapse/federation/federation_client.py61
-rw-r--r--synapse/federation/federation_server.py41
-rw-r--r--synapse/federation/transport/client.py30
-rw-r--r--synapse/handlers/event_auth.py85
-rw-r--r--synapse/handlers/federation.py54
-rw-r--r--synapse/handlers/room_member.py175
-rw-r--r--synapse/state/__init__.py12
-rw-r--r--synapse/state/v1.py40
-rw-r--r--synapse/state/v2.py11
13 files changed, 540 insertions, 79 deletions
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 054ab14ab6..dc662bca83 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -75,6 +75,9 @@ class Codes:
     INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
     USER_DEACTIVATED = "M_USER_DEACTIVATED"
     BAD_ALIAS = "M_BAD_ALIAS"
+    # For restricted join rules.
+    UNABLE_AUTHORISE_JOIN = "M_UNABLE_TO_AUTHORISE_JOIN"
+    UNABLE_TO_GRANT_JOIN = "M_UNABLE_TO_GRANT_JOIN"
 
 
 class CodeMessageException(RuntimeError):
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 8dd33dcb83..697319e52d 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -168,7 +168,7 @@ class RoomVersions:
         msc2403_knocking=False,
     )
     MSC3083 = RoomVersion(
-        "org.matrix.msc3083",
+        "org.matrix.msc3083.v2",
         RoomDisposition.UNSTABLE,
         EventFormatVersions.V3,
         StateResolutionVersions.V2,
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 137dff2513..cc92d35477 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -106,6 +106,18 @@ def check(
             if not event.signatures.get(event_id_domain):
                 raise AuthError(403, "Event not signed by sending server")
 
+        is_invite_via_allow_rule = (
+            event.type == EventTypes.Member
+            and event.membership == Membership.JOIN
+            and "join_authorised_via_users_server" in event.content
+        )
+        if is_invite_via_allow_rule:
+            authoriser_domain = get_domain_from_id(
+                event.content["join_authorised_via_users_server"]
+            )
+            if not event.signatures.get(authoriser_domain):
+                raise AuthError(403, "Event not signed by authorising server")
+
     # Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules
     #
     # 1. If type is m.room.create:
@@ -177,7 +189,7 @@ def check(
     # https://github.com/vector-im/vector-web/issues/1208 hopefully
     if event.type == EventTypes.ThirdPartyInvite:
         user_level = get_user_power_level(event.user_id, auth_events)
-        invite_level = _get_named_level(auth_events, "invite", 0)
+        invite_level = get_named_level(auth_events, "invite", 0)
 
         if user_level < invite_level:
             raise AuthError(403, "You don't have permission to invite users")
@@ -285,8 +297,8 @@ def _is_membership_change_allowed(
     user_level = get_user_power_level(event.user_id, auth_events)
     target_level = get_user_power_level(target_user_id, auth_events)
 
-    # FIXME (erikj): What should we do here as the default?
-    ban_level = _get_named_level(auth_events, "ban", 50)
+    invite_level = get_named_level(auth_events, "invite", 0)
+    ban_level = get_named_level(auth_events, "ban", 50)
 
     logger.debug(
         "_is_membership_change_allowed: %s",
@@ -336,8 +348,6 @@ def _is_membership_change_allowed(
         elif target_in_room:  # the target is already in the room.
             raise AuthError(403, "%s is already in the room." % target_user_id)
         else:
-            invite_level = _get_named_level(auth_events, "invite", 0)
-
             if user_level < invite_level:
                 raise AuthError(403, "You don't have permission to invite users")
     elif Membership.JOIN == membership:
@@ -345,16 +355,41 @@ def _is_membership_change_allowed(
         # * They are not banned.
         # * They are accepting a previously sent invitation.
         # * They are already joined (it's a NOOP).
-        # * The room is public or restricted.
+        # * The room is public.
+        # * The room is restricted and the user meets the allows rules.
         if event.user_id != target_user_id:
             raise AuthError(403, "Cannot force another user to join.")
         elif target_banned:
             raise AuthError(403, "You are banned from this room")
-        elif join_rule == JoinRules.PUBLIC or (
+        elif join_rule == JoinRules.PUBLIC:
+            pass
+        elif (
             room_version.msc3083_join_rules
             and join_rule == JoinRules.MSC3083_RESTRICTED
         ):
-            pass
+            # This is the same as public, but the event must contain a reference
+            # to the server who authorised the join. If the event does not contain
+            # the proper content it is rejected.
+            #
+            # Note that if the caller is in the room or invited, then they do
+            # not need to meet the allow rules.
+            if not caller_in_room and not caller_invited:
+                authorising_user = event.content.get("join_authorised_via_users_server")
+
+                if authorising_user is None:
+                    raise AuthError(403, "Join event is missing authorising user.")
+
+                # The authorising user must be in the room.
+                key = (EventTypes.Member, authorising_user)
+                member_event = auth_events.get(key)
+                _check_joined_room(member_event, authorising_user, event.room_id)
+
+                authorising_user_level = get_user_power_level(
+                    authorising_user, auth_events
+                )
+                if authorising_user_level < invite_level:
+                    raise AuthError(403, "Join event authorised by invalid server.")
+
         elif join_rule == JoinRules.INVITE or (
             room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
         ):
@@ -369,7 +404,7 @@ def _is_membership_change_allowed(
         if target_banned and user_level < ban_level:
             raise AuthError(403, "You cannot unban user %s." % (target_user_id,))
         elif target_user_id != event.user_id:
-            kick_level = _get_named_level(auth_events, "kick", 50)
+            kick_level = get_named_level(auth_events, "kick", 50)
 
             if user_level < kick_level or user_level <= target_level:
                 raise AuthError(403, "You cannot kick user %s." % target_user_id)
@@ -445,7 +480,7 @@ def get_send_level(
 
 
 def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
-    power_levels_event = _get_power_level_event(auth_events)
+    power_levels_event = get_power_level_event(auth_events)
 
     send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
     user_level = get_user_power_level(event.user_id, auth_events)
@@ -485,7 +520,7 @@ def check_redaction(
     """
     user_level = get_user_power_level(event.user_id, auth_events)
 
-    redact_level = _get_named_level(auth_events, "redact", 50)
+    redact_level = get_named_level(auth_events, "redact", 50)
 
     if user_level >= redact_level:
         return False
@@ -600,7 +635,7 @@ def _check_power_levels(
             )
 
 
-def _get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
+def get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
     return auth_events.get((EventTypes.PowerLevels, ""))
 
 
@@ -616,7 +651,7 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
     Returns:
         the user's power level in this room.
     """
-    power_level_event = _get_power_level_event(auth_events)
+    power_level_event = get_power_level_event(auth_events)
     if power_level_event:
         level = power_level_event.content.get("users", {}).get(user_id)
         if not level:
@@ -640,8 +675,8 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
             return 0
 
 
-def _get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
-    power_level_event = _get_power_level_event(auth_events)
+def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
+    power_level_event = get_power_level_event(auth_events)
 
     if not power_level_event:
         return default
@@ -728,7 +763,9 @@ def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
     return public_keys
 
 
-def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str, str]]:
+def auth_types_for_event(
+    room_version: RoomVersion, event: Union[EventBase, EventBuilder]
+) -> Set[Tuple[str, str]]:
     """Given an event, return a list of (EventType, StateKey) that may be
     needed to auth the event. The returned list may be a superset of what
     would actually be required depending on the full state of the room.
@@ -760,4 +797,12 @@ def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str
                 )
                 auth_types.add(key)
 
+        if room_version.msc3083_join_rules and membership == Membership.JOIN:
+            if "join_authorised_via_users_server" in event.content:
+                key = (
+                    EventTypes.Member,
+                    event.content["join_authorised_via_users_server"],
+                )
+                auth_types.add(key)
+
     return auth_types
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 2bfe6a3d37..024e440ff4 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -178,6 +178,34 @@ async def _check_sigs_on_pdu(
             )
             raise SynapseError(403, errmsg, Codes.FORBIDDEN)
 
+    # If this is a join event for a restricted room it may have been authorised
+    # via a different server from the sending server. Check those signatures.
+    if (
+        room_version.msc3083_join_rules
+        and pdu.type == EventTypes.Member
+        and pdu.membership == Membership.JOIN
+        and "join_authorised_via_users_server" in pdu.content
+    ):
+        authorising_server = get_domain_from_id(
+            pdu.content["join_authorised_via_users_server"]
+        )
+        try:
+            await keyring.verify_event_for_server(
+                authorising_server,
+                pdu,
+                pdu.origin_server_ts if room_version.enforce_key_validity else 0,
+            )
+        except Exception as e:
+            errmsg = (
+                "event id %s: unable to verify signature for authorising server %s: %s"
+                % (
+                    pdu.event_id,
+                    authorising_server,
+                    e,
+                )
+            )
+            raise SynapseError(403, errmsg, Codes.FORBIDDEN)
+
 
 def _is_invite_via_3pid(event: EventBase) -> bool:
     return (
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c767d30627..dbadf102f2 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -19,7 +19,6 @@ import itertools
 import logging
 from typing import (
     TYPE_CHECKING,
-    Any,
     Awaitable,
     Callable,
     Collection,
@@ -79,7 +78,15 @@ class InvalidResponseError(RuntimeError):
     we couldn't parse
     """
 
-    pass
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class SendJoinResult:
+    # The event to persist.
+    event: EventBase
+    # A string giving the server the event was sent to.
+    origin: str
+    state: List[EventBase]
+    auth_chain: List[EventBase]
 
 
 class FederationClient(FederationBase):
@@ -677,7 +684,7 @@ class FederationClient(FederationBase):
 
     async def send_join(
         self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion
-    ) -> Dict[str, Any]:
+    ) -> SendJoinResult:
         """Sends a join event to one of a list of homeservers.
 
         Doing so will cause the remote server to add the event to the graph,
@@ -691,18 +698,38 @@ class FederationClient(FederationBase):
                 did the make_join)
 
         Returns:
-            a dict with members ``origin`` (a string
-            giving the server the event was sent to, ``state`` (?) and
-            ``auth_chain``.
+            The result of the send join request.
 
         Raises:
             SynapseError: if the chosen remote server returns a 300/400 code, or
                 no servers successfully handle the request.
         """
 
-        async def send_request(destination) -> Dict[str, Any]:
+        async def send_request(destination) -> SendJoinResult:
             response = await self._do_send_join(room_version, destination, pdu)
 
+            # If an event was returned (and expected to be returned):
+            #
+            # * Ensure it has the same event ID (note that the event ID is a hash
+            #   of the event fields for versions which support MSC3083).
+            # * Ensure the signatures are good.
+            #
+            # Otherwise, fallback to the provided event.
+            if room_version.msc3083_join_rules and response.event:
+                event = response.event
+
+                valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
+                    pdu=event,
+                    origin=destination,
+                    outlier=True,
+                    room_version=room_version,
+                )
+
+                if valid_pdu is None or event.event_id != pdu.event_id:
+                    raise InvalidResponseError("Returned an invalid join event")
+            else:
+                event = pdu
+
             state = response.state
             auth_chain = response.auth_events
 
@@ -784,11 +811,21 @@ class FederationClient(FederationBase):
                     % (auth_chain_create_events,)
                 )
 
-            return {
-                "state": signed_state,
-                "auth_chain": signed_auth,
-                "origin": destination,
-            }
+            return SendJoinResult(
+                event=event,
+                state=signed_state,
+                auth_chain=signed_auth,
+                origin=destination,
+            )
+
+        if room_version.msc3083_join_rules:
+            # If the join is being authorised via allow rules, we need to send
+            # the /send_join back to the same server that was originally used
+            # with /make_join.
+            if "join_authorised_via_users_server" in pdu.content:
+                destinations = [
+                    get_domain_from_id(pdu.content["join_authorised_via_users_server"])
+                ]
 
         return await self._try_destination_list("send_join", destinations, send_request)
 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 29619aeeb8..2892a11d7d 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -45,6 +45,7 @@ from synapse.api.errors import (
     UnsupportedRoomVersionError,
 )
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.crypto.event_signing import compute_event_signature
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
@@ -64,7 +65,7 @@ from synapse.replication.http.federation import (
     ReplicationGetQueryRestServlet,
 )
 from synapse.storage.databases.main.lock import Lock
-from synapse.types import JsonDict
+from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
@@ -586,7 +587,7 @@ class FederationServer(FederationBase):
     async def on_send_join_request(
         self, origin: str, content: JsonDict, room_id: str
     ) -> Dict[str, Any]:
-        context = await self._on_send_membership_event(
+        event, context = await self._on_send_membership_event(
             origin, content, Membership.JOIN, room_id
         )
 
@@ -597,6 +598,7 @@ class FederationServer(FederationBase):
 
         time_now = self._clock.time_msec()
         return {
+            "org.matrix.msc3083.v2.event": event.get_pdu_json(),
             "state": [p.get_pdu_json(time_now) for p in state.values()],
             "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
         }
@@ -681,7 +683,7 @@ class FederationServer(FederationBase):
         Returns:
             The stripped room state.
         """
-        event_context = await self._on_send_membership_event(
+        _, context = await self._on_send_membership_event(
             origin, content, Membership.KNOCK, room_id
         )
 
@@ -690,14 +692,14 @@ class FederationServer(FederationBase):
         # related to the room while the knock request is pending.
         stripped_room_state = (
             await self.store.get_stripped_room_state_from_event_context(
-                event_context, self._room_prejoin_state_types
+                context, self._room_prejoin_state_types
             )
         )
         return {"knock_state_events": stripped_room_state}
 
     async def _on_send_membership_event(
         self, origin: str, content: JsonDict, membership_type: str, room_id: str
-    ) -> EventContext:
+    ) -> Tuple[EventBase, EventContext]:
         """Handle an on_send_{join,leave,knock} request
 
         Does some preliminary validation before passing the request on to the
@@ -712,7 +714,7 @@ class FederationServer(FederationBase):
                 in the event
 
         Returns:
-            The context of the event after inserting it into the room graph.
+            The event and context of the event after inserting it into the room graph.
 
         Raises:
             SynapseError if there is a problem with the request, including things like
@@ -748,6 +750,33 @@ class FederationServer(FederationBase):
 
         logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures)
 
+        # Sign the event since we're vouching on behalf of the remote server that
+        # the event is valid to be sent into the room. Currently this is only done
+        # if the user is being joined via restricted join rules.
+        if (
+            room_version.msc3083_join_rules
+            and event.membership == Membership.JOIN
+            and "join_authorised_via_users_server" in event.content
+        ):
+            # We can only authorise our own users.
+            authorising_server = get_domain_from_id(
+                event.content["join_authorised_via_users_server"]
+            )
+            if authorising_server != self.server_name:
+                raise SynapseError(
+                    400,
+                    f"Cannot authorise request from resident server: {authorising_server}",
+                )
+
+            event.signatures.update(
+                compute_event_signature(
+                    room_version,
+                    event.get_pdu_json(),
+                    self.hs.hostname,
+                    self.hs.signing_key,
+                )
+            )
+
         event = await self._check_sigs_and_hash(room_version, event)
 
         return await self.handler.on_send_membership_event(origin, event)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index e73bdb52b3..6a8d3ad4fe 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1219,8 +1219,26 @@ def _create_v2_path(path: str, *args: str) -> str:
 class SendJoinResponse:
     """The parsed response of a `/send_join` request."""
 
+    # The list of auth events from the /send_join response.
     auth_events: List[EventBase]
+    # The list of state from the /send_join response.
     state: List[EventBase]
+    # The raw join event from the /send_join response.
+    event_dict: JsonDict
+    # The parsed join event from the /send_join response. This will be None if
+    # "event" is not included in the response.
+    event: Optional[EventBase] = None
+
+
+@ijson.coroutine
+def _event_parser(event_dict: JsonDict):
+    """Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
+    to add them to a given dictionary.
+    """
+
+    while True:
+        key, value = yield
+        event_dict[key] = value
 
 
 @ijson.coroutine
@@ -1246,7 +1264,8 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
     CONTENT_TYPE = "application/json"
 
     def __init__(self, room_version: RoomVersion, v1_api: bool):
-        self._response = SendJoinResponse([], [])
+        self._response = SendJoinResponse([], [], {})
+        self._room_version = room_version
 
         # The V1 API has the shape of `[200, {...}]`, which we handle by
         # prefixing with `item.*`.
@@ -1260,12 +1279,21 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
             _event_list_parser(room_version, self._response.auth_events),
             prefix + "auth_chain.item",
         )
+        self._coro_event = ijson.kvitems_coro(
+            _event_parser(self._response.event_dict),
+            prefix + "org.matrix.msc3083.v2.event",
+        )
 
     def write(self, data: bytes) -> int:
         self._coro_state.send(data)
         self._coro_auth.send(data)
+        self._coro_event.send(data)
 
         return len(data)
 
     def finish(self) -> SendJoinResponse:
+        if self._response.event_dict:
+            self._response.event = make_event_from_dict(
+                self._response.event_dict, self._room_version
+            )
         return self._response
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 41dbdfd0a1..53fac1f8a3 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 from typing import TYPE_CHECKING, Collection, List, Optional, Union
 
 from synapse import event_auth
@@ -20,16 +21,18 @@ from synapse.api.constants import (
     Membership,
     RestrictedJoinRuleTypes,
 )
-from synapse.api.errors import AuthError
+from synapse.api.errors import AuthError, Codes, SynapseError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
 from synapse.events.builder import EventBuilder
-from synapse.types import StateMap
+from synapse.types import StateMap, get_domain_from_id
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
+logger = logging.getLogger(__name__)
+
 
 class EventAuthHandler:
     """
@@ -39,6 +42,7 @@ class EventAuthHandler:
     def __init__(self, hs: "HomeServer"):
         self._clock = hs.get_clock()
         self._store = hs.get_datastore()
+        self._server_name = hs.hostname
 
     async def check_from_context(
         self, room_version: str, event, context, do_sig_check=True
@@ -81,15 +85,76 @@ class EventAuthHandler:
         # introduce undesirable "state reset" behaviour.
         #
         # All of which sounds a bit tricky so we don't bother for now.
-
         auth_ids = []
-        for etype, state_key in event_auth.auth_types_for_event(event):
+        for etype, state_key in event_auth.auth_types_for_event(
+            event.room_version, event
+        ):
             auth_ev_id = current_state_ids.get((etype, state_key))
             if auth_ev_id:
                 auth_ids.append(auth_ev_id)
 
         return auth_ids
 
+    async def get_user_which_could_invite(
+        self, room_id: str, current_state_ids: StateMap[str]
+    ) -> str:
+        """
+        Searches the room state for a local user who has the power level necessary
+        to invite other users.
+
+        Args:
+            room_id: The room ID under search.
+            current_state_ids: The current state of the room.
+
+        Returns:
+            The MXID of the user which could issue an invite.
+
+        Raises:
+            SynapseError if no appropriate user is found.
+        """
+        power_level_event_id = current_state_ids.get((EventTypes.PowerLevels, ""))
+        invite_level = 0
+        users_default_level = 0
+        if power_level_event_id:
+            power_level_event = await self._store.get_event(power_level_event_id)
+            invite_level = power_level_event.content.get("invite", invite_level)
+            users_default_level = power_level_event.content.get(
+                "users_default", users_default_level
+            )
+            users = power_level_event.content.get("users", {})
+        else:
+            users = {}
+
+        # Find the user with the highest power level.
+        users_in_room = await self._store.get_users_in_room(room_id)
+        # Only interested in local users.
+        local_users_in_room = [
+            u for u in users_in_room if get_domain_from_id(u) == self._server_name
+        ]
+        chosen_user = max(
+            local_users_in_room,
+            key=lambda user: users.get(user, users_default_level),
+            default=None,
+        )
+
+        # Return the chosen if they can issue invites.
+        user_power_level = users.get(chosen_user, users_default_level)
+        if chosen_user and user_power_level >= invite_level:
+            logger.debug(
+                "Found a user who can issue invites  %s with power level %d >= invite level %d",
+                chosen_user,
+                user_power_level,
+                invite_level,
+            )
+            return chosen_user
+
+        # No user was found.
+        raise SynapseError(
+            400,
+            "Unable to find a user which could issue an invite",
+            Codes.UNABLE_TO_GRANT_JOIN,
+        )
+
     async def check_host_in_room(self, room_id: str, host: str) -> bool:
         with Measure(self._clock, "check_host_in_room"):
             return await self._store.is_host_joined(room_id, host)
@@ -134,6 +199,18 @@ class EventAuthHandler:
         # in any of them.
         allowed_rooms = await self.get_rooms_that_allow_join(state_ids)
         if not await self.is_user_in_rooms(allowed_rooms, user_id):
+
+            # If this is a remote request, the user might be in an allowed room
+            # that we do not know about.
+            if get_domain_from_id(user_id) != self._server_name:
+                for room_id in allowed_rooms:
+                    if not await self._store.is_host_joined(room_id, self._server_name):
+                        raise SynapseError(
+                            400,
+                            f"Unable to check if {user_id} is in allowed rooms.",
+                            Codes.UNABLE_AUTHORISE_JOIN,
+                        )
+
             raise AuthError(
                 403,
                 "You do not belong to any of the required rooms to join this room.",
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 5728719909..aba095d2e1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1494,9 +1494,10 @@ class FederationHandler(BaseHandler):
                 host_list, event, room_version_obj
             )
 
-            origin = ret["origin"]
-            state = ret["state"]
-            auth_chain = ret["auth_chain"]
+            event = ret.event
+            origin = ret.origin
+            state = ret.state
+            auth_chain = ret.auth_chain
             auth_chain.sort(key=lambda e: e.depth)
 
             logger.debug("do_invite_join auth_chain: %s", auth_chain)
@@ -1676,7 +1677,7 @@ class FederationHandler(BaseHandler):
 
         # checking the room version will check that we've actually heard of the room
         # (and return a 404 otherwise)
-        room_version = await self.store.get_room_version_id(room_id)
+        room_version = await self.store.get_room_version(room_id)
 
         # now check that we are *still* in the room
         is_in_room = await self._event_auth_handler.check_host_in_room(
@@ -1691,8 +1692,38 @@ class FederationHandler(BaseHandler):
 
         event_content = {"membership": Membership.JOIN}
 
+        # If the current room is using restricted join rules, additional information
+        # may need to be included in the event content in order to efficiently
+        # validate the event.
+        #
+        # Note that this requires the /send_join request to come back to the
+        # same server.
+        if room_version.msc3083_join_rules:
+            state_ids = await self.store.get_current_state_ids(room_id)
+            if await self._event_auth_handler.has_restricted_join_rules(
+                state_ids, room_version
+            ):
+                prev_member_event_id = state_ids.get((EventTypes.Member, user_id), None)
+                # If the user is invited or joined to the room already, then
+                # no additional info is needed.
+                include_auth_user_id = True
+                if prev_member_event_id:
+                    prev_member_event = await self.store.get_event(prev_member_event_id)
+                    include_auth_user_id = prev_member_event.membership not in (
+                        Membership.JOIN,
+                        Membership.INVITE,
+                    )
+
+                if include_auth_user_id:
+                    event_content[
+                        "join_authorised_via_users_server"
+                    ] = await self._event_auth_handler.get_user_which_could_invite(
+                        room_id,
+                        state_ids,
+                    )
+
         builder = self.event_builder_factory.new(
-            room_version,
+            room_version.identifier,
             {
                 "type": EventTypes.Member,
                 "content": event_content,
@@ -1710,10 +1741,13 @@ class FederationHandler(BaseHandler):
             logger.warning("Failed to create join to %s because %s", room_id, e)
             raise
 
+        # Ensure the user can even join the room.
+        await self._check_join_restrictions(context, event)
+
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
         await self._event_auth_handler.check_from_context(
-            room_version, event, context, do_sig_check=False
+            room_version.identifier, event, context, do_sig_check=False
         )
 
         return event
@@ -1958,7 +1992,7 @@ class FederationHandler(BaseHandler):
     @log_function
     async def on_send_membership_event(
         self, origin: str, event: EventBase
-    ) -> EventContext:
+    ) -> Tuple[EventBase, EventContext]:
         """
         We have received a join/leave/knock event for a room via send_join/leave/knock.
 
@@ -1981,7 +2015,7 @@ class FederationHandler(BaseHandler):
             event: The member event that has been signed by the remote homeserver.
 
         Returns:
-            The context of the event after inserting it into the room graph.
+            The event and context of the event after inserting it into the room graph.
 
         Raises:
             SynapseError if the event is not accepted into the room
@@ -2037,7 +2071,7 @@ class FederationHandler(BaseHandler):
 
         # all looks good, we can persist the event.
         await self._run_push_actions_and_persist_event(event, context)
-        return context
+        return event, context
 
     async def _check_join_restrictions(
         self, context: EventContext, event: EventBase
@@ -2473,7 +2507,7 @@ class FederationHandler(BaseHandler):
         )
 
         # Now check if event pass auth against said current state
-        auth_types = auth_types_for_event(event)
+        auth_types = auth_types_for_event(room_version_obj, event)
         current_state_ids_list = [
             e for k, e in current_state_ids.items() if k in auth_types
         ]
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 1192591609..65ad3efa6a 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -16,7 +16,7 @@ import abc
 import logging
 import random
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple
 
 from synapse import types
 from synapse.api.constants import AccountDataTypes, EventTypes, Membership
@@ -28,6 +28,7 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.api.ratelimiting import Ratelimiter
+from synapse.event_auth import get_named_level, get_power_level_event
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.types import (
@@ -340,16 +341,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         if event.membership == Membership.JOIN:
             newly_joined = True
-            prev_member_event = None
             if prev_member_event_id:
                 prev_member_event = await self.store.get_event(prev_member_event_id)
                 newly_joined = prev_member_event.membership != Membership.JOIN
 
-            # Check if the member should be allowed access via membership in a space.
-            await self.event_auth_handler.check_restricted_join_rules(
-                prev_state_ids, event.room_version, user_id, prev_member_event
-            )
-
             # Only rate-limit if the user actually joined the room, otherwise we'll end
             # up blocking profile updates.
             if newly_joined and ratelimit:
@@ -701,7 +696,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     # so don't really fit into the general auth process.
                     raise AuthError(403, "Guest access not allowed")
 
-            if not is_host_in_room:
+            # Check if a remote join should be performed.
+            remote_join, remote_room_hosts = await self._should_perform_remote_join(
+                target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
+            )
+            if remote_join:
                 if ratelimit:
                     time_now_s = self.clock.time()
                     (
@@ -826,6 +825,106 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             outlier=outlier,
         )
 
+    async def _should_perform_remote_join(
+        self,
+        user_id: str,
+        room_id: str,
+        remote_room_hosts: List[str],
+        content: JsonDict,
+        is_host_in_room: bool,
+    ) -> Tuple[bool, List[str]]:
+        """
+        Check whether the server should do a remote join (as opposed to a local
+        join) for a user.
+
+        Generally a remote join is used if:
+
+        * The server is not yet in the room.
+        * The server is in the room, the room has restricted join rules, the user
+          is not joined or invited to the room, and the server does not have
+          another user who is capable of issuing invites.
+
+        Args:
+            user_id: The user joining the room.
+            room_id: The room being joined.
+            remote_room_hosts: A list of remote room hosts.
+            content: The content to use as the event body of the join. This may
+                be modified.
+            is_host_in_room: True if the host is in the room.
+
+        Returns:
+            A tuple of:
+                True if a remote join should be performed. False if the join can be
+                done locally.
+
+                A list of remote room hosts to use. This is an empty list if a
+                local join is to be done.
+        """
+        # If the host isn't in the room, pass through the prospective hosts.
+        if not is_host_in_room:
+            return True, remote_room_hosts
+
+        # If the host is in the room, but not one of the authorised hosts
+        # for restricted join rules, a remote join must be used.
+        room_version = await self.store.get_room_version(room_id)
+        current_state_ids = await self.store.get_current_state_ids(room_id)
+
+        # If restricted join rules are not being used, a local join can always
+        # be used.
+        if not await self.event_auth_handler.has_restricted_join_rules(
+            current_state_ids, room_version
+        ):
+            return False, []
+
+        # If the user is invited to the room or already joined, the join
+        # event can always be issued locally.
+        prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None)
+        prev_member_event = None
+        if prev_member_event_id:
+            prev_member_event = await self.store.get_event(prev_member_event_id)
+            if prev_member_event.membership in (
+                Membership.JOIN,
+                Membership.INVITE,
+            ):
+                return False, []
+
+        # If the local host has a user who can issue invites, then a local
+        # join can be done.
+        #
+        # If not, generate a new list of remote hosts based on which
+        # can issue invites.
+        event_map = await self.store.get_events(current_state_ids.values())
+        current_state = {
+            state_key: event_map[event_id]
+            for state_key, event_id in current_state_ids.items()
+        }
+        allowed_servers = get_servers_from_users(
+            get_users_which_can_issue_invite(current_state)
+        )
+
+        # If the local server is not one of allowed servers, then a remote
+        # join must be done. Return the list of prospective servers based on
+        # which can issue invites.
+        if self.hs.hostname not in allowed_servers:
+            return True, list(allowed_servers)
+
+        # Ensure the member should be allowed access via membership in a room.
+        await self.event_auth_handler.check_restricted_join_rules(
+            current_state_ids, room_version, user_id, prev_member_event
+        )
+
+        # If this is going to be a local join, additional information must
+        # be included in the event content in order to efficiently validate
+        # the event.
+        content[
+            "join_authorised_via_users_server"
+        ] = await self.event_auth_handler.get_user_which_could_invite(
+            room_id,
+            current_state_ids,
+        )
+
+        return False, []
+
     async def transfer_room_state_on_room_upgrade(
         self, old_room_id: str, room_id: str
     ) -> None:
@@ -1514,3 +1613,63 @@ class RoomMemberMasterHandler(RoomMemberHandler):
 
         if membership:
             await self.store.forget(user_id, room_id)
+
+
+def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[str]:
+    """
+    Return the list of users which can issue invites.
+
+    This is done by exploring the joined users and comparing their power levels
+    to the necessyar power level to issue an invite.
+
+    Args:
+        auth_events: state in force at this point in the room
+
+    Returns:
+        The users which can issue invites.
+    """
+    invite_level = get_named_level(auth_events, "invite", 0)
+    users_default_level = get_named_level(auth_events, "users_default", 0)
+    power_level_event = get_power_level_event(auth_events)
+
+    # Custom power-levels for users.
+    if power_level_event:
+        users = power_level_event.content.get("users", {})
+    else:
+        users = {}
+
+    result = []
+
+    # Check which members are able to invite by ensuring they're joined and have
+    # the necessary power level.
+    for (event_type, state_key), event in auth_events.items():
+        if event_type != EventTypes.Member:
+            continue
+
+        if event.membership != Membership.JOIN:
+            continue
+
+        # Check if the user has a custom power level.
+        if users.get(state_key, users_default_level) >= invite_level:
+            result.append(state_key)
+
+    return result
+
+
+def get_servers_from_users(users: List[str]) -> Set[str]:
+    """
+    Resolve a list of users into their servers.
+
+    Args:
+        users: A list of users.
+
+    Returns:
+        A set of servers.
+    """
+    servers = set()
+    for user in users:
+        try:
+            servers.add(get_domain_from_id(user))
+        except SynapseError:
+            pass
+    return servers
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 6223daf522..2e15471435 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -636,16 +636,20 @@ class StateResolutionHandler:
         """
         try:
             with Measure(self.clock, "state._resolve_events") as m:
-                v = KNOWN_ROOM_VERSIONS[room_version]
-                if v.state_res == StateResolutionVersions.V1:
+                room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+                if room_version_obj.state_res == StateResolutionVersions.V1:
                     return await v1.resolve_events_with_store(
-                        room_id, state_sets, event_map, state_res_store.get_events
+                        room_id,
+                        room_version_obj,
+                        state_sets,
+                        event_map,
+                        state_res_store.get_events,
                     )
                 else:
                     return await v2.resolve_events_with_store(
                         self.clock,
                         room_id,
-                        room_version,
+                        room_version_obj,
                         state_sets,
                         event_map,
                         state_res_store,
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 267193cedf..92336d7cc8 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -29,7 +29,7 @@ from typing import (
 from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
-from synapse.api.room_versions import RoomVersions
+from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.events import EventBase
 from synapse.types import MutableStateMap, StateMap
 
@@ -41,6 +41,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
 
 async def resolve_events_with_store(
     room_id: str,
+    room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
     state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
@@ -104,7 +105,7 @@ async def resolve_events_with_store(
     # get the ids of the auth events which allow us to authenticate the
     # conflicted state, picking only from the unconflicting state.
     auth_events = _create_auth_events_from_maps(
-        unconflicted_state, conflicted_state, state_map
+        room_version, unconflicted_state, conflicted_state, state_map
     )
 
     new_needed_events = set(auth_events.values())
@@ -132,7 +133,7 @@ async def resolve_events_with_store(
     state_map.update(state_map_new)
 
     return _resolve_with_state(
-        unconflicted_state, conflicted_state, auth_events, state_map
+        room_version, unconflicted_state, conflicted_state, auth_events, state_map
     )
 
 
@@ -187,6 +188,7 @@ def _seperate(
 
 
 def _create_auth_events_from_maps(
+    room_version: RoomVersion,
     unconflicted_state: StateMap[str],
     conflicted_state: StateMap[Set[str]],
     state_map: Dict[str, EventBase],
@@ -194,6 +196,7 @@ def _create_auth_events_from_maps(
     """
 
     Args:
+        room_version: The room version.
         unconflicted_state: The unconflicted state map.
         conflicted_state: The conflicted state map.
         state_map:
@@ -205,7 +208,9 @@ def _create_auth_events_from_maps(
     for event_ids in conflicted_state.values():
         for event_id in event_ids:
             if event_id in state_map:
-                keys = event_auth.auth_types_for_event(state_map[event_id])
+                keys = event_auth.auth_types_for_event(
+                    room_version, state_map[event_id]
+                )
                 for key in keys:
                     if key not in auth_events:
                         auth_event_id = unconflicted_state.get(key, None)
@@ -215,6 +220,7 @@ def _create_auth_events_from_maps(
 
 
 def _resolve_with_state(
+    room_version: RoomVersion,
     unconflicted_state_ids: MutableStateMap[str],
     conflicted_state_ids: StateMap[Set[str]],
     auth_event_ids: StateMap[str],
@@ -235,7 +241,9 @@ def _resolve_with_state(
     }
 
     try:
-        resolved_state = _resolve_state_events(conflicted_state, auth_events)
+        resolved_state = _resolve_state_events(
+            room_version, conflicted_state, auth_events
+        )
     except Exception:
         logger.exception("Failed to resolve state")
         raise
@@ -248,7 +256,9 @@ def _resolve_with_state(
 
 
 def _resolve_state_events(
-    conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
+    room_version: RoomVersion,
+    conflicted_state: StateMap[List[EventBase]],
+    auth_events: MutableStateMap[EventBase],
 ) -> StateMap[EventBase]:
     """This is where we actually decide which of the conflicted state to
     use.
@@ -263,21 +273,27 @@ def _resolve_state_events(
     if POWER_KEY in conflicted_state:
         events = conflicted_state[POWER_KEY]
         logger.debug("Resolving conflicted power levels %r", events)
-        resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events)
+        resolved_state[POWER_KEY] = _resolve_auth_events(
+            room_version, events, auth_events
+        )
 
     auth_events.update(resolved_state)
 
     for key, events in conflicted_state.items():
         if key[0] == EventTypes.JoinRules:
             logger.debug("Resolving conflicted join rules %r", events)
-            resolved_state[key] = _resolve_auth_events(events, auth_events)
+            resolved_state[key] = _resolve_auth_events(
+                room_version, events, auth_events
+            )
 
     auth_events.update(resolved_state)
 
     for key, events in conflicted_state.items():
         if key[0] == EventTypes.Member:
             logger.debug("Resolving conflicted member lists %r", events)
-            resolved_state[key] = _resolve_auth_events(events, auth_events)
+            resolved_state[key] = _resolve_auth_events(
+                room_version, events, auth_events
+            )
 
     auth_events.update(resolved_state)
 
@@ -290,12 +306,14 @@ def _resolve_state_events(
 
 
 def _resolve_auth_events(
-    events: List[EventBase], auth_events: StateMap[EventBase]
+    room_version: RoomVersion, events: List[EventBase], auth_events: StateMap[EventBase]
 ) -> EventBase:
     reverse = list(reversed(_ordered_events(events)))
 
     auth_keys = {
-        key for event in events for key in event_auth.auth_types_for_event(event)
+        key
+        for event in events
+        for key in event_auth.auth_types_for_event(room_version, event)
     }
 
     new_auth_events = {}
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index e66e6571c8..7b1e8361de 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -36,7 +36,7 @@ import synapse.state
 from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
 from synapse.types import MutableStateMap, StateMap
 from synapse.util import Clock
@@ -53,7 +53,7 @@ _AWAIT_AFTER_ITERATIONS = 100
 async def resolve_events_with_store(
     clock: Clock,
     room_id: str,
-    room_version: str,
+    room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
     state_res_store: "synapse.state.StateResolutionStore",
@@ -497,7 +497,7 @@ async def _reverse_topological_power_sort(
 async def _iterative_auth_checks(
     clock: Clock,
     room_id: str,
-    room_version: str,
+    room_version: RoomVersion,
     event_ids: List[str],
     base_state: StateMap[str],
     event_map: Dict[str, EventBase],
@@ -519,7 +519,6 @@ async def _iterative_auth_checks(
         Returns the final updated state
     """
     resolved_state = dict(base_state)
-    room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
     for idx, event_id in enumerate(event_ids, start=1):
         event = event_map[event_id]
@@ -538,7 +537,7 @@ async def _iterative_auth_checks(
                 if ev.rejected_reason is None:
                     auth_events[(ev.type, ev.state_key)] = ev
 
-        for key in event_auth.auth_types_for_event(event):
+        for key in event_auth.auth_types_for_event(room_version, event):
             if key in resolved_state:
                 ev_id = resolved_state[key]
                 ev = await _get_event(room_id, ev_id, event_map, state_res_store)
@@ -548,7 +547,7 @@ async def _iterative_auth_checks(
 
         try:
             event_auth.check(
-                room_version_obj,
+                room_version,
                 event,
                 auth_events,
                 do_sig_check=False,