diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index c86fc5df0d..a21cbe9fa8 100644
--- a/tests/rest/client/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -76,7 +76,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
"""
Upgrading a room should work fine.
"""
- # THe user isn't in the room.
+ # The user isn't in the room.
roomless = self.register_user("roomless", "pass")
roomless_token = self.login(roomless, "pass")
@@ -263,3 +263,33 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
self.assertIn((EventTypes.SpaceChild, self.room_id), state_ids)
# The child that was removed should not be copied over.
self.assertNotIn((EventTypes.SpaceChild, old_room_id), state_ids)
+
+ def test_custom_room_type(self) -> None:
+ """Test upgrading a room that has a custom room type set."""
+ test_room_type = "com.example.my_custom_room_type"
+
+ # Create a room with a custom room type.
+ room_id = self.helper.create_room_as(
+ self.creator,
+ tok=self.creator_token,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: test_room_type}
+ },
+ )
+
+ # Upgrade the room!
+ channel = self._upgrade_room(room_id=room_id)
+ self.assertEqual(200, channel.code, channel.result)
+ self.assertIn("replacement_room", channel.json_body)
+
+ new_room_id = channel.json_body["replacement_room"]
+
+ state_ids = self.get_success(self.store.get_current_state_ids(new_room_id))
+
+ # Ensure the new room is the same type as the old room.
+ create_event = self.get_success(
+ self.store.get_event(state_ids[(EventTypes.Create, "")])
+ )
+ self.assertEqual(
+ create_event.content.get(EventContentFields.ROOM_TYPE), test_room_type
+ )
|