summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_directory.py66
-rw-r--r--tests/rest/client/v1/test_rooms.py160
-rw-r--r--tests/test_types.py2
3 files changed, 191 insertions, 37 deletions
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 27b916aed4..3397cfa485 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -88,6 +88,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         )
 
     def test_delete_alias_not_allowed(self):
+        """Removing an alias should be denied if a user does not have the proper permissions."""
         room_id = "!8765qwer:test"
         self.get_success(
             self.store.create_room_alias_association(self.my_room, room_id, ["test"])
@@ -101,6 +102,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         )
 
     def test_delete_alias(self):
+        """Removing an alias should work when a user does has the proper permissions."""
         room_id = "!8765qwer:test"
         user_id = "@user:test"
         self.get_success(
@@ -159,30 +161,42 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         )
 
         self.test_alias = "#test:test"
-        self.room_alias = RoomAlias.from_string(self.test_alias)
+        self.room_alias = self._add_alias(self.test_alias)
+
+    def _add_alias(self, alias: str) -> RoomAlias:
+        """Add an alias to the test room."""
+        room_alias = RoomAlias.from_string(alias)
 
         # Create a new alias to this room.
         self.get_success(
             self.store.create_room_alias_association(
-                self.room_alias, self.room_id, ["test"], self.admin_user
+                room_alias, self.room_id, ["test"], self.admin_user
             )
         )
+        return room_alias
 
-    def test_remove_alias(self):
-        """Removing an alias that is the canonical alias should remove it there too."""
-        # Set this new alias as the canonical alias for this room
+    def _set_canonical_alias(self, content):
+        """Configure the canonical alias state on the room."""
         self.helper.send_state(
-            self.room_id,
-            "m.room.canonical_alias",
-            {"alias": self.test_alias, "alt_aliases": [self.test_alias]},
-            tok=self.admin_user_tok,
+            self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok,
         )
 
-        data = self.get_success(
+    def _get_canonical_alias(self):
+        """Get the canonical alias state of the room."""
+        return self.get_success(
             self.state_handler.get_current_state(
                 self.room_id, EventTypes.CanonicalAlias, ""
             )
         )
+
+    def test_remove_alias(self):
+        """Removing an alias that is the canonical alias should remove it there too."""
+        # Set this new alias as the canonical alias for this room
+        self._set_canonical_alias(
+            {"alias": self.test_alias, "alt_aliases": [self.test_alias]}
+        )
+
+        data = self._get_canonical_alias()
         self.assertEqual(data["content"]["alias"], self.test_alias)
         self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
 
@@ -193,11 +207,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        data = self.get_success(
-            self.state_handler.get_current_state(
-                self.room_id, EventTypes.CanonicalAlias, ""
-            )
-        )
+        data = self._get_canonical_alias()
         self.assertNotIn("alias", data["content"])
         self.assertNotIn("alt_aliases", data["content"])
 
@@ -205,29 +215,17 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         """Removing an alias listed as in alt_aliases should remove it there too."""
         # Create a second alias.
         other_test_alias = "#test2:test"
-        other_room_alias = RoomAlias.from_string(other_test_alias)
-        self.get_success(
-            self.store.create_room_alias_association(
-                other_room_alias, self.room_id, ["test"], self.admin_user
-            )
-        )
+        other_room_alias = self._add_alias(other_test_alias)
 
         # Set the alias as the canonical alias for this room.
-        self.helper.send_state(
-            self.room_id,
-            "m.room.canonical_alias",
+        self._set_canonical_alias(
             {
                 "alias": self.test_alias,
                 "alt_aliases": [self.test_alias, other_test_alias],
-            },
-            tok=self.admin_user_tok,
+            }
         )
 
-        data = self.get_success(
-            self.state_handler.get_current_state(
-                self.room_id, EventTypes.CanonicalAlias, ""
-            )
-        )
+        data = self._get_canonical_alias()
         self.assertEqual(data["content"]["alias"], self.test_alias)
         self.assertEqual(
             data["content"]["alt_aliases"], [self.test_alias, other_test_alias]
@@ -240,11 +238,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        data = self.get_success(
-            self.state_handler.get_current_state(
-                self.room_id, EventTypes.CanonicalAlias, ""
-            )
-        )
+        data = self._get_canonical_alias()
         self.assertEqual(data["content"]["alias"], self.test_alias)
         self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 2f3df5f88f..7dd86d0c27 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1821,3 +1821,163 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
         )
         self.render(request)
         self.assertEqual(channel.code, expected_code, channel.result)
+
+
+class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        directory.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, homeserver):
+        self.room_owner = self.register_user("room_owner", "test")
+        self.room_owner_tok = self.login("room_owner", "test")
+
+        self.room_id = self.helper.create_room_as(
+            self.room_owner, tok=self.room_owner_tok
+        )
+
+        self.alias = "#alias:test"
+        self._set_alias_via_directory(self.alias)
+
+    def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
+        url = "/_matrix/client/r0/directory/room/" + alias
+        data = {"room_id": self.room_id}
+        request_data = json.dumps(data)
+
+        request, channel = self.make_request(
+            "PUT", url, request_data, access_token=self.room_owner_tok
+        )
+        self.render(request)
+        self.assertEqual(channel.code, expected_code, channel.result)
+
+    def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict:
+        """Calls the endpoint under test. returns the json response object."""
+        request, channel = self.make_request(
+            "GET",
+            "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
+            access_token=self.room_owner_tok,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, expected_code, channel.result)
+        res = channel.json_body
+        self.assertIsInstance(res, dict)
+        return res
+
+    def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
+        """Calls the endpoint under test. returns the json response object."""
+        request, channel = self.make_request(
+            "PUT",
+            "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
+            json.dumps(content),
+            access_token=self.room_owner_tok,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, expected_code, channel.result)
+        res = channel.json_body
+        self.assertIsInstance(res, dict)
+        return res
+
+    def test_canonical_alias(self):
+        """Test a basic alias message."""
+        # There is no canonical alias to start with.
+        self._get_canonical_alias(expected_code=404)
+
+        # Create an alias.
+        self._set_canonical_alias({"alias": self.alias})
+
+        # Canonical alias now exists!
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {"alias": self.alias})
+
+        # Now remove the alias.
+        self._set_canonical_alias({})
+
+        # There is an alias event, but it is empty.
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {})
+
+    def test_alt_aliases(self):
+        """Test a canonical alias message with alt_aliases."""
+        # Create an alias.
+        self._set_canonical_alias({"alt_aliases": [self.alias]})
+
+        # Canonical alias now exists!
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {"alt_aliases": [self.alias]})
+
+        # Now remove the alt_aliases.
+        self._set_canonical_alias({})
+
+        # There is an alias event, but it is empty.
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {})
+
+    def test_alias_alt_aliases(self):
+        """Test a canonical alias message with an alias and alt_aliases."""
+        # Create an alias.
+        self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+        # Canonical alias now exists!
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
+
+        # Now remove the alias and alt_aliases.
+        self._set_canonical_alias({})
+
+        # There is an alias event, but it is empty.
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {})
+
+    def test_partial_modify(self):
+        """Test removing only the alt_aliases."""
+        # Create an alias.
+        self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+        # Canonical alias now exists!
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
+
+        # Now remove the alt_aliases.
+        self._set_canonical_alias({"alias": self.alias})
+
+        # There is an alias event, but it is empty.
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {"alias": self.alias})
+
+    def test_add_alias(self):
+        """Test removing only the alt_aliases."""
+        # Create an additional alias.
+        second_alias = "#second:test"
+        self._set_alias_via_directory(second_alias)
+
+        # Add the canonical alias.
+        self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+        # Then add the second alias.
+        self._set_canonical_alias(
+            {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
+        )
+
+        # Canonical alias now exists!
+        res = self._get_canonical_alias()
+        self.assertEqual(
+            res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
+        )
+
+    def test_bad_data(self):
+        """Invalid data for alt_aliases should cause errors."""
+        self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400)
+        self._set_canonical_alias({"alt_aliases": None}, expected_code=400)
+        self._set_canonical_alias({"alt_aliases": 0}, expected_code=400)
+        self._set_canonical_alias({"alt_aliases": 1}, expected_code=400)
+        self._set_canonical_alias({"alt_aliases": False}, expected_code=400)
+        self._set_canonical_alias({"alt_aliases": True}, expected_code=400)
+        self._set_canonical_alias({"alt_aliases": {}}, expected_code=400)
+
+    def test_bad_alias(self):
+        """An alias which does not point to the room raises a SynapseError."""
+        self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
+        self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
diff --git a/tests/test_types.py b/tests/test_types.py
index 8d97c751ea..480bea1bdc 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -75,7 +75,7 @@ class GroupIDTestCase(unittest.TestCase):
                 self.fail("Parsing '%s' should raise exception" % id_string)
             except SynapseError as exc:
                 self.assertEqual(400, exc.code)
-                self.assertEqual("M_UNKNOWN", exc.errcode)
+                self.assertEqual("M_INVALID_PARAM", exc.errcode)
 
 
 class MapUsernameTestCase(unittest.TestCase):