summary refs log tree commit diff
path: root/synapse/handlers/federation_event.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation_event.py')
-rw-r--r--synapse/handlers/federation_event.py59
1 files changed, 53 insertions, 6 deletions
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 3561f2f1de..b7136f8d1c 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -47,6 +47,7 @@ from synapse.api.errors import (
     FederationError,
     FederationPullAttemptBackoffError,
     HttpResponseException,
+    PartialStateConflictError,
     RequestSendFailed,
     SynapseError,
 )
@@ -74,7 +75,6 @@ from synapse.replication.http.federation import (
     ReplicationFederationSendEventsRestServlet,
 )
 from synapse.state import StateResolutionStore
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import (
     PersistedEventPosition,
@@ -441,16 +441,17 @@ class FederationEventHandler:
         # Check if the user is already in the room or invited to the room.
         user_id = event.state_key
         prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
-        prev_member_event = None
+        prev_membership = None
         if prev_member_event_id:
             prev_member_event = await self._store.get_event(prev_member_event_id)
+            prev_membership = prev_member_event.membership
 
         # Check if the member should be allowed access via membership in a space.
         await self._event_auth_handler.check_restricted_join_rules(
             prev_state_ids,
             event.room_version,
             user_id,
-            prev_member_event,
+            prev_membership,
         )
 
     @trace
@@ -526,11 +527,57 @@ class FederationEventHandler:
             "Peristing join-via-remote %s (partial_state: %s)", event, partial_state
         )
         with nested_logging_context(suffix=event.event_id):
+            if partial_state:
+                # When handling a second partial state join into a partial state room,
+                # the returned state will exclude the membership from the first join. To
+                # preserve prior memberships, we try to compute the partial state before
+                # the event ourselves if we know about any of the prev events.
+                #
+                # When we don't know about any of the prev events, it's fine to just use
+                # the returned state, since the new join will create a new forward
+                # extremity, and leave the forward extremity containing our prior
+                # memberships alone.
+                prev_event_ids = set(event.prev_event_ids())
+                seen_event_ids = await self._store.have_events_in_timeline(
+                    prev_event_ids
+                )
+                missing_event_ids = prev_event_ids - seen_event_ids
+
+                state_maps_to_resolve: List[StateMap[str]] = []
+
+                # Fetch the state after the prev events that we know about.
+                state_maps_to_resolve.extend(
+                    (
+                        await self._state_storage_controller.get_state_groups_ids(
+                            room_id, seen_event_ids, await_full_state=False
+                        )
+                    ).values()
+                )
+
+                # When there are prev events we do not have the state for, we state
+                # resolve with the state returned by the remote homeserver.
+                if missing_event_ids or len(state_maps_to_resolve) == 0:
+                    state_maps_to_resolve.append(
+                        {(e.type, e.state_key): e.event_id for e in state}
+                    )
+
+                state_ids_before_event = (
+                    await self._state_resolution_handler.resolve_events_with_store(
+                        event.room_id,
+                        room_version.identifier,
+                        state_maps_to_resolve,
+                        event_map=None,
+                        state_res_store=StateResolutionStore(self._store),
+                    )
+                )
+            else:
+                state_ids_before_event = {
+                    (e.type, e.state_key): e.event_id for e in state
+                }
+
             context = await self._state_handler.compute_event_context(
                 event,
-                state_ids_before_event={
-                    (e.type, e.state_key): e.event_id for e in state
-                },
+                state_ids_before_event=state_ids_before_event,
                 partial_state=partial_state,
             )