summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/14606.misc1
-rw-r--r--synapse/api/errors.py22
-rw-r--r--synapse/federation/federation_server.py2
-rw-r--r--synapse/handlers/event_auth.py16
-rw-r--r--synapse/handlers/federation.py2
-rw-r--r--synapse/handlers/federation_event.py59
-rw-r--r--synapse/handlers/message.py2
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--synapse/handlers/room_member.py118
-rw-r--r--synapse/handlers/room_member_worker.py5
-rw-r--r--synapse/storage/databases/main/events.py21
-rw-r--r--tests/handlers/test_federation.py40
12 files changed, 196 insertions, 94 deletions
diff --git a/changelog.d/14606.misc b/changelog.d/14606.misc
new file mode 100644
index 0000000000..e2debc96d8
--- /dev/null
+++ b/changelog.d/14606.misc
@@ -0,0 +1 @@
+Faster joins: don't stall when another user joins during a fast join resync.
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index c2c177fd71..9235ce6536 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -751,3 +751,25 @@ class ModuleFailedException(Exception):
     Raised when a module API callback fails, for example because it raised an
     exception.
     """
+
+
+class PartialStateConflictError(SynapseError):
+    """An internal error raised when attempting to persist an event with partial state
+    after the room containing the event has been un-partial stated.
+
+    This error should be handled by recomputing the event context and trying again.
+
+    This error has an HTTP status code so that it can be transported over replication.
+    It should not be exposed to clients.
+    """
+
+    @staticmethod
+    def message() -> str:
+        return "Cannot persist partial state event in un-partial stated room"
+
+    def __init__(self) -> None:
+        super().__init__(
+            HTTPStatus.CONFLICT,
+            msg=PartialStateConflictError.message(),
+            errcode=Codes.UNKNOWN,
+        )
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 6addc0bb65..6d99845de5 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -48,6 +48,7 @@ from synapse.api.errors import (
     FederationError,
     IncompatibleRoomVersionError,
     NotFoundError,
+    PartialStateConflictError,
     SynapseError,
     UnsupportedRoomVersionError,
 )
@@ -81,7 +82,6 @@ from synapse.replication.http.federation import (
     ReplicationFederationSendEduRestServlet,
     ReplicationGetQueryRestServlet,
 )
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.lock import Lock
 from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
 from synapse.storage.roommember import MemberSummary
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index a23a8ce2a1..46dd63c3f0 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -202,7 +202,7 @@ class EventAuthHandler:
         state_ids: StateMap[str],
         room_version: RoomVersion,
         user_id: str,
-        prev_member_event: Optional[EventBase],
+        prev_membership: Optional[str],
     ) -> None:
         """
         Check whether a user can join a room without an invite due to restricted join rules.
@@ -214,15 +214,14 @@ class EventAuthHandler:
             state_ids: The state of the room as it currently is.
             room_version: The room version of the room being joined.
             user_id: The user joining the room.
-            prev_member_event: The current membership event for this user.
+            prev_membership: The current membership state for this user. `None` if the
+                user has never joined the room (equivalent to "leave").
 
         Raises:
             AuthError if the user cannot join the room.
         """
         # If the member is invited or currently joined, then nothing to do.
-        if prev_member_event and (
-            prev_member_event.membership in (Membership.JOIN, Membership.INVITE)
-        ):
+        if prev_membership in (Membership.JOIN, Membership.INVITE):
             return
 
         # This is not a room with a restricted join rule, so we don't need to do the
@@ -255,13 +254,14 @@ class EventAuthHandler:
             )
 
     async def has_restricted_join_rules(
-        self, state_ids: StateMap[str], room_version: RoomVersion
+        self, partial_state_ids: StateMap[str], room_version: RoomVersion
     ) -> bool:
         """
         Return if the room has the proper join rules set for access via rooms.
 
         Args:
-            state_ids: The state of the room as it currently is.
+            state_ids: The state of the room as it currently is. May be full or partial
+                state.
             room_version: The room version of the room to query.
 
         Returns:
@@ -272,7 +272,7 @@ class EventAuthHandler:
             return False
 
         # If there's no join rule, then it defaults to invite (so this doesn't apply).
-        join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
+        join_rules_event_id = partial_state_ids.get((EventTypes.JoinRules, ""), None)
         if not join_rules_event_id:
             return False
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 43ed4a3dd1..08727e4857 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -49,6 +49,7 @@ from synapse.api.errors import (
     FederationPullAttemptBackoffError,
     HttpResponseException,
     NotFoundError,
+    PartialStateConflictError,
     RequestSendFailed,
     SynapseError,
 )
@@ -68,7 +69,6 @@ from synapse.replication.http.federation import (
     ReplicationCleanRoomRestServlet,
     ReplicationStoreRoomOnOutlierMembershipRestServlet,
 )
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import JsonDict, StrCollection, get_domain_from_id
 from synapse.types.state import StateFilter
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 3561f2f1de..b7136f8d1c 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -47,6 +47,7 @@ from synapse.api.errors import (
     FederationError,
     FederationPullAttemptBackoffError,
     HttpResponseException,
+    PartialStateConflictError,
     RequestSendFailed,
     SynapseError,
 )
@@ -74,7 +75,6 @@ from synapse.replication.http.federation import (
     ReplicationFederationSendEventsRestServlet,
 )
 from synapse.state import StateResolutionStore
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import (
     PersistedEventPosition,
@@ -441,16 +441,17 @@ class FederationEventHandler:
         # Check if the user is already in the room or invited to the room.
         user_id = event.state_key
         prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
-        prev_member_event = None
+        prev_membership = None
         if prev_member_event_id:
             prev_member_event = await self._store.get_event(prev_member_event_id)
+            prev_membership = prev_member_event.membership
 
         # Check if the member should be allowed access via membership in a space.
         await self._event_auth_handler.check_restricted_join_rules(
             prev_state_ids,
             event.room_version,
             user_id,
-            prev_member_event,
+            prev_membership,
         )
 
     @trace
@@ -526,11 +527,57 @@ class FederationEventHandler:
             "Peristing join-via-remote %s (partial_state: %s)", event, partial_state
         )
         with nested_logging_context(suffix=event.event_id):
+            if partial_state:
+                # When handling a second partial state join into a partial state room,
+                # the returned state will exclude the membership from the first join. To
+                # preserve prior memberships, we try to compute the partial state before
+                # the event ourselves if we know about any of the prev events.
+                #
+                # When we don't know about any of the prev events, it's fine to just use
+                # the returned state, since the new join will create a new forward
+                # extremity, and leave the forward extremity containing our prior
+                # memberships alone.
+                prev_event_ids = set(event.prev_event_ids())
+                seen_event_ids = await self._store.have_events_in_timeline(
+                    prev_event_ids
+                )
+                missing_event_ids = prev_event_ids - seen_event_ids
+
+                state_maps_to_resolve: List[StateMap[str]] = []
+
+                # Fetch the state after the prev events that we know about.
+                state_maps_to_resolve.extend(
+                    (
+                        await self._state_storage_controller.get_state_groups_ids(
+                            room_id, seen_event_ids, await_full_state=False
+                        )
+                    ).values()
+                )
+
+                # When there are prev events we do not have the state for, we state
+                # resolve with the state returned by the remote homeserver.
+                if missing_event_ids or len(state_maps_to_resolve) == 0:
+                    state_maps_to_resolve.append(
+                        {(e.type, e.state_key): e.event_id for e in state}
+                    )
+
+                state_ids_before_event = (
+                    await self._state_resolution_handler.resolve_events_with_store(
+                        event.room_id,
+                        room_version.identifier,
+                        state_maps_to_resolve,
+                        event_map=None,
+                        state_res_store=StateResolutionStore(self._store),
+                    )
+                )
+            else:
+                state_ids_before_event = {
+                    (e.type, e.state_key): e.event_id for e in state
+                }
+
             context = await self._state_handler.compute_event_context(
                 event,
-                state_ids_before_event={
-                    (e.type, e.state_key): e.event_id for e in state
-                },
+                state_ids_before_event=state_ids_before_event,
                 partial_state=partial_state,
             )
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 3e30f52e4d..8f5b658d9d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -38,6 +38,7 @@ from synapse.api.errors import (
     Codes,
     ConsentNotGivenError,
     NotFoundError,
+    PartialStateConflictError,
     ShadowBanError,
     SynapseError,
     UnstableSpecAuthError,
@@ -57,7 +58,6 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.replication.http.send_events import ReplicationSendEventsRestServlet
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import (
     MutableStateMap,
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 060bbcb181..837dabb3b7 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -43,6 +43,7 @@ from synapse.api.errors import (
     Codes,
     LimitExceededError,
     NotFoundError,
+    PartialStateConflictError,
     StoreError,
     SynapseError,
 )
@@ -54,7 +55,6 @@ from synapse.events.utils import copy_and_fixup_power_levels_contents
 from synapse.handlers.relations import BundledAggregations
 from synapse.module_api import NOT_SPAM
 from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.streams import EventSource
 from synapse.types import (
     JsonDict,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 6e7141d2ef..a965c7ec76 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -26,7 +26,13 @@ from synapse.api.constants import (
     GuestAccess,
     Membership,
 )
-from synapse.api.errors import AuthError, Codes, ShadowBanError, SynapseError
+from synapse.api.errors import (
+    AuthError,
+    Codes,
+    PartialStateConflictError,
+    ShadowBanError,
+    SynapseError,
+)
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.event_auth import get_named_level, get_power_level_event
 from synapse.events import EventBase
@@ -34,7 +40,6 @@ from synapse.events.snapshot import EventContext
 from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
 from synapse.logging import opentracing
 from synapse.module_api import NOT_SPAM
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.types import (
     JsonDict,
     Requester,
@@ -56,6 +61,13 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class NoKnownServersError(SynapseError):
+    """No server already resident to the room was provided to the join/knock operation."""
+
+    def __init__(self, msg: str = "No known servers"):
+        super().__init__(404, msg)
+
+
 class RoomMemberHandler(metaclass=abc.ABCMeta):
     # TODO(paul): This handler currently contains a messy conflation of
     #   low-level API that works on UserID objects and so on, and REST-level
@@ -185,6 +197,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             room_id: Room that we are trying to join
             user: User who is trying to join
             content: A dict that should be used as the content of the join event.
+
+        Raises:
+            NoKnownServersError: if remote_room_hosts does not contain a server joined to
+                the room.
         """
         raise NotImplementedError()
 
@@ -823,14 +839,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         latest_event_ids = await self.store.get_prev_events_for_room(room_id)
 
-        state_before_join = await self.state_handler.compute_state_after_events(
-            room_id, latest_event_ids
+        is_partial_state_room = await self.store.is_partial_state_room(room_id)
+        partial_state_before_join = await self.state_handler.compute_state_after_events(
+            room_id, latest_event_ids, await_full_state=False
         )
+        # `is_partial_state_room` also indicates whether `partial_state_before_join` is
+        # partial.
 
         # TODO: Refactor into dictionary of explicitly allowed transitions
         # between old and new state, with specific error messages for some
         # transitions and generic otherwise
-        old_state_id = state_before_join.get((EventTypes.Member, target.to_string()))
+        old_state_id = partial_state_before_join.get(
+            (EventTypes.Member, target.to_string())
+        )
         if old_state_id:
             old_state = await self.store.get_event(old_state_id, allow_none=True)
             old_membership = old_state.content.get("membership") if old_state else None
@@ -881,11 +902,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             if action == "kick":
                 raise AuthError(403, "The target user is not in the room")
 
-        is_host_in_room = await self._is_host_in_room(state_before_join)
+        is_host_in_room = await self._is_host_in_room(partial_state_before_join)
 
         if effective_membership_state == Membership.JOIN:
             if requester.is_guest:
-                guest_can_join = await self._can_guest_join(state_before_join)
+                guest_can_join = await self._can_guest_join(partial_state_before_join)
                 if not guest_can_join:
                     # This should be an auth check, but guests are a local concept,
                     # so don't really fit into the general auth process.
@@ -927,8 +948,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 room_id,
                 remote_room_hosts,
                 content,
+                is_partial_state_room,
                 is_host_in_room,
-                state_before_join,
+                partial_state_before_join,
             )
             if remote_join:
                 if ratelimit:
@@ -1073,8 +1095,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         room_id: str,
         remote_room_hosts: List[str],
         content: JsonDict,
+        is_partial_state_room: bool,
         is_host_in_room: bool,
-        state_before_join: StateMap[str],
+        partial_state_before_join: StateMap[str],
     ) -> Tuple[bool, List[str]]:
         """
         Check whether the server should do a remote join (as opposed to a local
@@ -1093,9 +1116,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             remote_room_hosts: A list of remote room hosts.
             content: The content to use as the event body of the join. This may
                 be modified.
-            is_host_in_room: True if the host is in the room.
-            state_before_join: The state before the join event (i.e. the resolution of
-                the states after its parent events).
+            is_partial_state_room: `True` if the server currently doesn't hold the full
+                state of the room.
+            is_host_in_room: `True` if the host is in the room.
+            partial_state_before_join: The state before the join event (i.e. the
+                resolution of the states after its parent events). May be full or
+                partial state, depending on `is_partial_state_room`.
 
         Returns:
             A tuple of:
@@ -1109,6 +1135,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         if not is_host_in_room:
             return True, remote_room_hosts
 
+        prev_member_event_id = partial_state_before_join.get(
+            (EventTypes.Member, user_id), None
+        )
+        previous_membership = None
+        if prev_member_event_id:
+            prev_member_event = await self.store.get_event(prev_member_event_id)
+            previous_membership = prev_member_event.membership
+
+        # If we are not fully joined yet, and the target is not already in the room,
+        # let's do a remote join so another server with the full state can validate
+        # that the user has not been banned for example.
+        # We could just accept the join and wait for state res to resolve that later on
+        # but we would then leak room history to this person until then, which is pretty
+        # bad.
+        if is_partial_state_room and previous_membership != Membership.JOIN:
+            return True, remote_room_hosts
+
         # If the host is in the room, but not one of the authorised hosts
         # for restricted join rules, a remote join must be used.
         room_version = await self.store.get_room_version(room_id)
@@ -1116,21 +1159,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # If restricted join rules are not being used, a local join can always
         # be used.
         if not await self.event_auth_handler.has_restricted_join_rules(
-            state_before_join, room_version
+            partial_state_before_join, room_version
         ):
             return False, []
 
         # If the user is invited to the room or already joined, the join
         # event can always be issued locally.
-        prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None)
-        prev_member_event = None
-        if prev_member_event_id:
-            prev_member_event = await self.store.get_event(prev_member_event_id)
-            if prev_member_event.membership in (
-                Membership.JOIN,
-                Membership.INVITE,
-            ):
-                return False, []
+        if previous_membership in (Membership.JOIN, Membership.INVITE):
+            return False, []
+
+        # All the partial state cases are covered above. We have been given the full
+        # state of the room.
+        assert not is_partial_state_room
+        state_before_join = partial_state_before_join
 
         # If the local host has a user who can issue invites, then a local
         # join can be done.
@@ -1154,7 +1195,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         # Ensure the member should be allowed access via membership in a room.
         await self.event_auth_handler.check_restricted_join_rules(
-            state_before_join, room_version, user_id, prev_member_event
+            state_before_join, room_version, user_id, previous_membership
         )
 
         # If this is going to be a local join, additional information must
@@ -1304,11 +1345,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 if prev_member_event.membership == Membership.JOIN:
                     await self._user_left_room(target_user, room_id)
 
-    async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
+    async def _can_guest_join(self, partial_current_state_ids: StateMap[str]) -> bool:
         """
         Returns whether a guest can join a room based on its current state.
+
+        Args:
+            partial_current_state_ids: The current state of the room. May be full or
+                partial state.
         """
-        guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
+        guest_access_id = partial_current_state_ids.get(
+            (EventTypes.GuestAccess, ""), None
+        )
         if not guest_access_id:
             return False
 
@@ -1634,19 +1681,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         )
         return event, stream_id
 
-    async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
+    async def _is_host_in_room(self, partial_current_state_ids: StateMap[str]) -> bool:
+        """Returns whether the homeserver is in the room based on its current state.
+
+        Args:
+            partial_current_state_ids: The current state of the room. May be full or
+                partial state.
+        """
         # Have we just created the room, and is this about to be the very
         # first member event?
-        create_event_id = current_state_ids.get(("m.room.create", ""))
-        if len(current_state_ids) == 1 and create_event_id:
+        create_event_id = partial_current_state_ids.get(("m.room.create", ""))
+        if len(partial_current_state_ids) == 1 and create_event_id:
             # We can only get here if we're in the process of creating the room
             return True
 
-        for etype, state_key in current_state_ids:
+        for etype, state_key in partial_current_state_ids:
             if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
                 continue
 
-            event_id = current_state_ids[(etype, state_key)]
+            event_id = partial_current_state_ids[(etype, state_key)]
             event = await self.store.get_event(event_id, allow_none=True)
             if not event:
                 continue
@@ -1715,8 +1768,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         ]
 
         if len(remote_room_hosts) == 0:
-            raise SynapseError(
-                404,
+            raise NoKnownServersError(
                 "Can't join remote room because no servers "
                 "that are in the room have been provided.",
             )
@@ -1947,7 +1999,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         ]
 
         if len(remote_room_hosts) == 0:
-            raise SynapseError(404, "No known servers")
+            raise NoKnownServersError()
 
         return await self.federation_handler.do_knock(
             remote_room_hosts, room_id, user.to_string(), content=content
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 221552a2a6..ba261702d4 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -15,8 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
-from synapse.api.errors import SynapseError
-from synapse.handlers.room_member import RoomMemberHandler
+from synapse.handlers.room_member import NoKnownServersError, RoomMemberHandler
 from synapse.replication.http.membership import (
     ReplicationRemoteJoinRestServlet as ReplRemoteJoin,
     ReplicationRemoteKnockRestServlet as ReplRemoteKnock,
@@ -52,7 +51,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
     ) -> Tuple[str, int]:
         """Implements RoomMemberHandler._remote_join"""
         if len(remote_room_hosts) == 0:
-            raise SynapseError(404, "No known servers")
+            raise NoKnownServersError()
 
         ret = await self._remote_join_client(
             requester=requester,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index cb66376fb4..ffe766fd56 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -16,7 +16,6 @@
 import itertools
 import logging
 from collections import OrderedDict
-from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -36,7 +35,7 @@ from prometheus_client import Counter
 
 import synapse.metrics
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import PartialStateConflictError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
@@ -72,24 +71,6 @@ event_counter = Counter(
 )
 
 
-class PartialStateConflictError(SynapseError):
-    """An internal error raised when attempting to persist an event with partial state
-    after the room containing the event has been un-partial stated.
-
-    This error should be handled by recomputing the event context and trying again.
-
-    This error has an HTTP status code so that it can be transported over replication.
-    It should not be exposed to clients.
-    """
-
-    def __init__(self) -> None:
-        super().__init__(
-            HTTPStatus.CONFLICT,
-            msg="Cannot persist partial state event in un-partial stated room",
-            errcode=Codes.UNKNOWN,
-        )
-
-
 @attr.s(slots=True, auto_attribs=True)
 class DeltaState:
     """Deltas to use to update the `current_state_events` table.
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 57675fa407..5868eb2da7 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -575,26 +575,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
         fed_client = fed_handler.federation_client
 
         room_id = "!room:example.com"
-        membership_event = make_event_from_dict(
-            {
-                "room_id": room_id,
-                "type": "m.room.member",
-                "sender": "@alice:test",
-                "state_key": "@alice:test",
-                "content": {"membership": "join"},
-            },
-            RoomVersions.V10,
-        )
-
-        mock_make_membership_event = Mock(
-            return_value=make_awaitable(
-                (
-                    "example.com",
-                    membership_event,
-                    RoomVersions.V10,
-                )
-            )
-        )
 
         EVENT_CREATE = make_event_from_dict(
             {
@@ -640,6 +620,26 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             },
             room_version=RoomVersions.V10,
         )
+        membership_event = make_event_from_dict(
+            {
+                "room_id": room_id,
+                "type": "m.room.member",
+                "sender": "@alice:test",
+                "state_key": "@alice:test",
+                "content": {"membership": "join"},
+                "prev_events": [EVENT_INVITATION_MEMBERSHIP.event_id],
+            },
+            RoomVersions.V10,
+        )
+        mock_make_membership_event = Mock(
+            return_value=make_awaitable(
+                (
+                    "example.com",
+                    membership_event,
+                    RoomVersions.V10,
+                )
+            )
+        )
         mock_send_join = Mock(
             return_value=make_awaitable(
                 SendJoinResult(