summary refs log tree commit diff
path: root/tests/rest/client/test_upgrade_room.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_upgrade_room.py')
-rw-r--r--tests/rest/client/test_upgrade_room.py83
1 files changed, 78 insertions, 5 deletions
diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 98c1039d33..5e7bf97482 100644
--- a/tests/rest/client/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -48,10 +48,14 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
         self.helper.join(self.room_id, self.other, tok=self.other_token)
 
     def _upgrade_room(
-        self, token: Optional[str] = None, room_id: Optional[str] = None
+        self,
+        token: Optional[str] = None,
+        room_id: Optional[str] = None,
+        expire_cache: bool = True,
     ) -> FakeChannel:
-        # We never want a cached response.
-        self.reactor.advance(5 * 60 + 1)
+        if expire_cache:
+            # We don't want a cached response.
+            self.reactor.advance(5 * 60 + 1)
 
         if room_id is None:
             room_id = self.room_id
@@ -72,9 +76,24 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
         self.assertEqual(200, channel.code, channel.result)
         self.assertIn("replacement_room", channel.json_body)
 
-    def test_not_in_room(self) -> None:
+        new_room_id = channel.json_body["replacement_room"]
+
+        # Check that the tombstone event points to the new room.
+        tombstone_event = self.get_success(
+            self.hs.get_storage_controllers().state.get_current_state_event(
+                self.room_id, EventTypes.Tombstone, ""
+            )
+        )
+        self.assertIsNotNone(tombstone_event)
+        self.assertEqual(new_room_id, tombstone_event.content["replacement_room"])
+
+        # Check that the new room exists.
+        room = self.get_success(self.store.get_room(new_room_id))
+        self.assertIsNotNone(room)
+
+    def test_never_in_room(self) -> None:
         """
-        Upgrading a room should work fine.
+        A user who has never been in the room cannot upgrade the room.
         """
         # The user isn't in the room.
         roomless = self.register_user("roomless", "pass")
@@ -83,6 +102,16 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
         channel = self._upgrade_room(roomless_token)
         self.assertEqual(403, channel.code, channel.result)
 
+    def test_left_room(self) -> None:
+        """
+        A user who is no longer in the room cannot upgrade the room.
+        """
+        # Remove the user from the room.
+        self.helper.leave(self.room_id, self.creator, tok=self.creator_token)
+
+        channel = self._upgrade_room(self.creator_token)
+        self.assertEqual(403, channel.code, channel.result)
+
     def test_power_levels(self) -> None:
         """
         Another user can upgrade the room if their power level is increased.
@@ -297,3 +326,47 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
         self.assertEqual(
             create_event.content.get(EventContentFields.ROOM_TYPE), test_room_type
         )
+
+    def test_second_upgrade_from_same_user(self) -> None:
+        """A second room upgrade from the same user is deduplicated."""
+        channel1 = self._upgrade_room()
+        self.assertEqual(200, channel1.code, channel1.result)
+
+        channel2 = self._upgrade_room(expire_cache=False)
+        self.assertEqual(200, channel2.code, channel2.result)
+
+        self.assertEqual(
+            channel1.json_body["replacement_room"],
+            channel2.json_body["replacement_room"],
+        )
+
+    def test_second_upgrade_after_delay(self) -> None:
+        """A second room upgrade is not deduplicated after some time has passed."""
+        channel1 = self._upgrade_room()
+        self.assertEqual(200, channel1.code, channel1.result)
+
+        channel2 = self._upgrade_room(expire_cache=True)
+        self.assertEqual(200, channel2.code, channel2.result)
+
+        self.assertNotEqual(
+            channel1.json_body["replacement_room"],
+            channel2.json_body["replacement_room"],
+        )
+
+    def test_second_upgrade_from_different_user(self) -> None:
+        """A second room upgrade from a different user is blocked."""
+        channel = self._upgrade_room()
+        self.assertEqual(200, channel.code, channel.result)
+
+        channel = self._upgrade_room(self.other_token, expire_cache=False)
+        self.assertEqual(400, channel.code, channel.result)
+
+    def test_first_upgrade_does_not_block_second(self) -> None:
+        """A second room upgrade is not blocked when a previous upgrade attempt was not
+        allowed.
+        """
+        channel = self._upgrade_room(self.other_token)
+        self.assertEqual(403, channel.code, channel.result)
+
+        channel = self._upgrade_room(expire_cache=False)
+        self.assertEqual(200, channel.code, channel.result)