diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 2f3df5f88f..7dd86d0c27 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1821,3 +1821,163 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
+
+
+class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ directory.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.room_owner = self.register_user("room_owner", "test")
+ self.room_owner_tok = self.login("room_owner", "test")
+
+ self.room_id = self.helper.create_room_as(
+ self.room_owner, tok=self.room_owner_tok
+ )
+
+ self.alias = "#alias:test"
+ self._set_alias_via_directory(self.alias)
+
+ def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
+ url = "/_matrix/client/r0/directory/room/" + alias
+ data = {"room_id": self.room_id}
+ request_data = json.dumps(data)
+
+ request, channel = self.make_request(
+ "PUT", url, request_data, access_token=self.room_owner_tok
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict:
+ """Calls the endpoint under test. returns the json response object."""
+ request, channel = self.make_request(
+ "GET",
+ "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
+ access_token=self.room_owner_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+ res = channel.json_body
+ self.assertIsInstance(res, dict)
+ return res
+
+ def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
+ """Calls the endpoint under test. returns the json response object."""
+ request, channel = self.make_request(
+ "PUT",
+ "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
+ json.dumps(content),
+ access_token=self.room_owner_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+ res = channel.json_body
+ self.assertIsInstance(res, dict)
+ return res
+
+ def test_canonical_alias(self):
+ """Test a basic alias message."""
+ # There is no canonical alias to start with.
+ self._get_canonical_alias(expected_code=404)
+
+ # Create an alias.
+ self._set_canonical_alias({"alias": self.alias})
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alias": self.alias})
+
+ # Now remove the alias.
+ self._set_canonical_alias({})
+
+ # There is an alias event, but it is empty.
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {})
+
+ def test_alt_aliases(self):
+ """Test a canonical alias message with alt_aliases."""
+ # Create an alias.
+ self._set_canonical_alias({"alt_aliases": [self.alias]})
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alt_aliases": [self.alias]})
+
+ # Now remove the alt_aliases.
+ self._set_canonical_alias({})
+
+ # There is an alias event, but it is empty.
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {})
+
+ def test_alias_alt_aliases(self):
+ """Test a canonical alias message with an alias and alt_aliases."""
+ # Create an alias.
+ self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Now remove the alias and alt_aliases.
+ self._set_canonical_alias({})
+
+ # There is an alias event, but it is empty.
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {})
+
+ def test_partial_modify(self):
+ """Test removing only the alt_aliases."""
+ # Create an alias.
+ self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Now remove the alt_aliases.
+ self._set_canonical_alias({"alias": self.alias})
+
+ # There is an alias event, but it is empty.
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alias": self.alias})
+
+ def test_add_alias(self):
+ """Test removing only the alt_aliases."""
+ # Create an additional alias.
+ second_alias = "#second:test"
+ self._set_alias_via_directory(second_alias)
+
+ # Add the canonical alias.
+ self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Then add the second alias.
+ self._set_canonical_alias(
+ {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
+ )
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(
+ res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
+ )
+
+ def test_bad_data(self):
+ """Invalid data for alt_aliases should cause errors."""
+ self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": None}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": 0}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": 1}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": False}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": True}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": {}}, expected_code=400)
+
+ def test_bad_alias(self):
+ """An alias which does not point to the room raises a SynapseError."""
+ self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
|