diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 2e19df0976..3a65ccbb55 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -47,6 +47,7 @@ from synapse.api.errors import (
FederationError,
FederationPullAttemptBackoffError,
HttpResponseException,
+ PartialStateConflictError,
RequestSendFailed,
SynapseError,
)
@@ -58,7 +59,7 @@ from synapse.event_auth import (
validate_event_for_room_version,
)
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo
from synapse.logging.context import nested_logging_context
from synapse.logging.opentracing import (
@@ -74,7 +75,6 @@ from synapse.replication.http.federation import (
ReplicationFederationSendEventsRestServlet,
)
from synapse.state import StateResolutionStore
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import (
PersistedEventPosition,
@@ -426,7 +426,9 @@ class FederationEventHandler:
return event, context
async def check_join_restrictions(
- self, context: EventContext, event: EventBase
+ self,
+ context: UnpersistedEventContextBase,
+ event: EventBase,
) -> None:
"""Check that restrictions in restricted join rules are matched
@@ -439,16 +441,17 @@ class FederationEventHandler:
# Check if the user is already in the room or invited to the room.
user_id = event.state_key
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
- prev_member_event = None
+ prev_membership = None
if prev_member_event_id:
prev_member_event = await self._store.get_event(prev_member_event_id)
+ prev_membership = prev_member_event.membership
# Check if the member should be allowed access via membership in a space.
await self._event_auth_handler.check_restricted_join_rules(
prev_state_ids,
event.room_version,
user_id,
- prev_member_event,
+ prev_membership,
)
@trace
@@ -524,11 +527,57 @@ class FederationEventHandler:
"Peristing join-via-remote %s (partial_state: %s)", event, partial_state
)
with nested_logging_context(suffix=event.event_id):
+ if partial_state:
+ # When handling a second partial state join into a partial state room,
+ # the returned state will exclude the membership from the first join. To
+ # preserve prior memberships, we try to compute the partial state before
+ # the event ourselves if we know about any of the prev events.
+ #
+ # When we don't know about any of the prev events, it's fine to just use
+ # the returned state, since the new join will create a new forward
+ # extremity, and leave the forward extremity containing our prior
+ # memberships alone.
+ prev_event_ids = set(event.prev_event_ids())
+ seen_event_ids = await self._store.have_events_in_timeline(
+ prev_event_ids
+ )
+ missing_event_ids = prev_event_ids - seen_event_ids
+
+ state_maps_to_resolve: List[StateMap[str]] = []
+
+ # Fetch the state after the prev events that we know about.
+ state_maps_to_resolve.extend(
+ (
+ await self._state_storage_controller.get_state_groups_ids(
+ room_id, seen_event_ids, await_full_state=False
+ )
+ ).values()
+ )
+
+ # When there are prev events we do not have the state for, we state
+ # resolve with the state returned by the remote homeserver.
+ if missing_event_ids or len(state_maps_to_resolve) == 0:
+ state_maps_to_resolve.append(
+ {(e.type, e.state_key): e.event_id for e in state}
+ )
+
+ state_ids_before_event = (
+ await self._state_resolution_handler.resolve_events_with_store(
+ event.room_id,
+ room_version.identifier,
+ state_maps_to_resolve,
+ event_map=None,
+ state_res_store=StateResolutionStore(self._store),
+ )
+ )
+ else:
+ state_ids_before_event = {
+ (e.type, e.state_key): e.event_id for e in state
+ }
+
context = await self._state_handler.compute_event_context(
event,
- state_ids_before_event={
- (e.type, e.state_key): e.event_id for e in state
- },
+ state_ids_before_event=state_ids_before_event,
partial_state=partial_state,
)
|