diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index aea96a0986..745750b1d7 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -14,6 +14,7 @@
import logging
from typing import cast
from unittest import TestCase
+from unittest.mock import Mock, patch
from twisted.test.proto_helpers import MemoryReactor
@@ -22,6 +23,7 @@ from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseErro
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.federation_base import event_from_pdu_json
+from synapse.federation.federation_client import SendJoinResult
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
@@ -30,7 +32,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
-from tests.test_utils import event_injection
+from tests.test_utils import event_injection, make_awaitable
logger = logging.getLogger(__name__)
@@ -456,3 +458,121 @@ class EventFromPduTestCase(TestCase):
},
RoomVersions.V6,
)
+
+
+class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
+ def test_failed_partial_join_is_clean(self) -> None:
+ """
+ Tests that, when failing to partial-join a room, we don't get stuck with
+ a partial-state flag on a room.
+ """
+
+ fed_handler = self.hs.get_federation_handler()
+ fed_client = fed_handler.federation_client
+
+ room_id = "!room:example.com"
+ membership_event = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "sender": "@alice:test",
+ "state_key": "@alice:test",
+ "content": {"membership": "join"},
+ },
+ RoomVersions.V10,
+ )
+
+ mock_make_membership_event = Mock(
+ return_value=make_awaitable(
+ (
+ "example.com",
+ membership_event,
+ RoomVersions.V10,
+ )
+ )
+ )
+
+ EVENT_CREATE = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.create",
+ "sender": "@kristina:example.com",
+ "state_key": "",
+ "depth": 0,
+ "content": {"creator": "@kristina:example.com", "room_version": "10"},
+ "auth_events": [],
+ "origin_server_ts": 1,
+ },
+ room_version=RoomVersions.V10,
+ )
+ EVENT_CREATOR_MEMBERSHIP = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "sender": "@kristina:example.com",
+ "state_key": "@kristina:example.com",
+ "content": {"membership": "join"},
+ "depth": 1,
+ "prev_events": [EVENT_CREATE.event_id],
+ "auth_events": [EVENT_CREATE.event_id],
+ "origin_server_ts": 1,
+ },
+ room_version=RoomVersions.V10,
+ )
+ EVENT_INVITATION_MEMBERSHIP = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "sender": "@kristina:example.com",
+ "state_key": "@alice:test",
+ "content": {"membership": "invite"},
+ "depth": 2,
+ "prev_events": [EVENT_CREATOR_MEMBERSHIP.event_id],
+ "auth_events": [
+ EVENT_CREATE.event_id,
+ EVENT_CREATOR_MEMBERSHIP.event_id,
+ ],
+ "origin_server_ts": 1,
+ },
+ room_version=RoomVersions.V10,
+ )
+ mock_send_join = Mock(
+ return_value=make_awaitable(
+ SendJoinResult(
+ membership_event,
+ "example.com",
+ state=[
+ EVENT_CREATE,
+ EVENT_CREATOR_MEMBERSHIP,
+ EVENT_INVITATION_MEMBERSHIP,
+ ],
+ auth_chain=[
+ EVENT_CREATE,
+ EVENT_CREATOR_MEMBERSHIP,
+ EVENT_INVITATION_MEMBERSHIP,
+ ],
+ partial_state=True,
+ servers_in_room=["example.com"],
+ )
+ )
+ )
+
+ with patch.object(
+ fed_client, "make_membership_event", mock_make_membership_event
+ ), patch.object(fed_client, "send_join", mock_send_join):
+ # Join and check that our join event is rejected
+ # (The join event is rejected because it doesn't have any signatures)
+ join_exc = self.get_failure(
+ fed_handler.do_invite_join(["example.com"], room_id, "@alice:test", {}),
+ SynapseError,
+ )
+ self.assertIn("Join event was rejected", str(join_exc))
+
+ store = self.hs.get_datastores().main
+
+ # Check that we don't have a left-over partial_state entry.
+ self.assertFalse(
+ self.get_success(store.is_partial_state_room(room_id)),
+ f"Stale partial-stated room flag left over for {room_id} after a"
+ f" failed do_invite_join!",
+ )
|