summary refs log tree commit diff
path: root/tests/handlers/test_federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_federation.py')
-rw-r--r--tests/handlers/test_federation.py136
1 files changed, 132 insertions, 4 deletions
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 8a0bb91f40..745750b1d7 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -14,6 +14,7 @@
 import logging
 from typing import cast
 from unittest import TestCase
+from unittest.mock import Mock, patch
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -22,6 +23,7 @@ from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseErro
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, make_event_from_dict
 from synapse.federation.federation_base import event_from_pdu_json
+from synapse.federation.federation_client import SendJoinResult
 from synapse.logging.context import LoggingContext, run_in_background
 from synapse.rest import admin
 from synapse.rest.client import login, room
@@ -30,7 +32,7 @@ from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 from tests import unittest
-from tests.test_utils import event_injection
+from tests.test_utils import event_injection, make_awaitable
 
 logger = logging.getLogger(__name__)
 
@@ -280,13 +282,21 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
 
             # we poke this directly into _process_received_pdu, to avoid the
             # federation handler wanting to backfill the fake event.
+            state_handler = self.hs.get_state_handler()
+            context = self.get_success(
+                state_handler.compute_event_context(
+                    event,
+                    state_ids_before_event={
+                        (e.type, e.state_key): e.event_id for e in current_state
+                    },
+                    partial_state=False,
+                )
+            )
             self.get_success(
                 federation_event_handler._process_received_pdu(
                     self.OTHER_SERVER_NAME,
                     event,
-                    state_ids={
-                        (e.type, e.state_key): e.event_id for e in current_state
-                    },
+                    context,
                 )
             )
 
@@ -448,3 +458,121 @@ class EventFromPduTestCase(TestCase):
                 },
                 RoomVersions.V6,
             )
+
+
+class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
+    def test_failed_partial_join_is_clean(self) -> None:
+        """
+        Tests that, when failing to partial-join a room, we don't get stuck with
+        a partial-state flag on a room.
+        """
+
+        fed_handler = self.hs.get_federation_handler()
+        fed_client = fed_handler.federation_client
+
+        room_id = "!room:example.com"
+        membership_event = make_event_from_dict(
+            {
+                "room_id": room_id,
+                "type": "m.room.member",
+                "sender": "@alice:test",
+                "state_key": "@alice:test",
+                "content": {"membership": "join"},
+            },
+            RoomVersions.V10,
+        )
+
+        mock_make_membership_event = Mock(
+            return_value=make_awaitable(
+                (
+                    "example.com",
+                    membership_event,
+                    RoomVersions.V10,
+                )
+            )
+        )
+
+        EVENT_CREATE = make_event_from_dict(
+            {
+                "room_id": room_id,
+                "type": "m.room.create",
+                "sender": "@kristina:example.com",
+                "state_key": "",
+                "depth": 0,
+                "content": {"creator": "@kristina:example.com", "room_version": "10"},
+                "auth_events": [],
+                "origin_server_ts": 1,
+            },
+            room_version=RoomVersions.V10,
+        )
+        EVENT_CREATOR_MEMBERSHIP = make_event_from_dict(
+            {
+                "room_id": room_id,
+                "type": "m.room.member",
+                "sender": "@kristina:example.com",
+                "state_key": "@kristina:example.com",
+                "content": {"membership": "join"},
+                "depth": 1,
+                "prev_events": [EVENT_CREATE.event_id],
+                "auth_events": [EVENT_CREATE.event_id],
+                "origin_server_ts": 1,
+            },
+            room_version=RoomVersions.V10,
+        )
+        EVENT_INVITATION_MEMBERSHIP = make_event_from_dict(
+            {
+                "room_id": room_id,
+                "type": "m.room.member",
+                "sender": "@kristina:example.com",
+                "state_key": "@alice:test",
+                "content": {"membership": "invite"},
+                "depth": 2,
+                "prev_events": [EVENT_CREATOR_MEMBERSHIP.event_id],
+                "auth_events": [
+                    EVENT_CREATE.event_id,
+                    EVENT_CREATOR_MEMBERSHIP.event_id,
+                ],
+                "origin_server_ts": 1,
+            },
+            room_version=RoomVersions.V10,
+        )
+        mock_send_join = Mock(
+            return_value=make_awaitable(
+                SendJoinResult(
+                    membership_event,
+                    "example.com",
+                    state=[
+                        EVENT_CREATE,
+                        EVENT_CREATOR_MEMBERSHIP,
+                        EVENT_INVITATION_MEMBERSHIP,
+                    ],
+                    auth_chain=[
+                        EVENT_CREATE,
+                        EVENT_CREATOR_MEMBERSHIP,
+                        EVENT_INVITATION_MEMBERSHIP,
+                    ],
+                    partial_state=True,
+                    servers_in_room=["example.com"],
+                )
+            )
+        )
+
+        with patch.object(
+            fed_client, "make_membership_event", mock_make_membership_event
+        ), patch.object(fed_client, "send_join", mock_send_join):
+            # Join and check that our join event is rejected
+            # (The join event is rejected because it doesn't have any signatures)
+            join_exc = self.get_failure(
+                fed_handler.do_invite_join(["example.com"], room_id, "@alice:test", {}),
+                SynapseError,
+            )
+        self.assertIn("Join event was rejected", str(join_exc))
+
+        store = self.hs.get_datastores().main
+
+        # Check that we don't have a left-over partial_state entry.
+        self.assertFalse(
+            self.get_success(store.is_partial_state_room(room_id)),
+            f"Stale partial-stated room flag left over for {room_id} after a"
+            f" failed do_invite_join!",
+        )