summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/admin/test_federation.py302
1 files changed, 277 insertions, 25 deletions
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index e2d3cff2a3..71068d16cd 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -20,7 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.api.errors import Codes
-from synapse.rest.client import login
+from synapse.rest.client import login, room
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
@@ -52,9 +52,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         ]
     )
     def test_requester_is_no_admin(self, method: str, url: str) -> None:
-        """
-        If the user is not a server admin, an error 403 is returned.
-        """
+        """If the user is not a server admin, an error 403 is returned."""
 
         self.register_user("user", "pass", admin=False)
         other_user_tok = self.login("user", "pass")
@@ -70,9 +68,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_invalid_parameter(self) -> None:
-        """
-        If parameters are invalid, an error is returned.
-        """
+        """If parameters are invalid, an error is returned."""
 
         # negative limit
         channel = self.make_request(
@@ -135,9 +131,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
     def test_limit(self) -> None:
-        """
-        Testing list of destinations with limit
-        """
+        """Testing list of destinations with limit"""
 
         number_destinations = 20
         self._create_destinations(number_destinations)
@@ -155,9 +149,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self._check_fields(channel.json_body["destinations"])
 
     def test_from(self) -> None:
-        """
-        Testing list of destinations with a defined starting point (from)
-        """
+        """Testing list of destinations with a defined starting point (from)"""
 
         number_destinations = 20
         self._create_destinations(number_destinations)
@@ -175,9 +167,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self._check_fields(channel.json_body["destinations"])
 
     def test_limit_and_from(self) -> None:
-        """
-        Testing list of destinations with a defined starting point and limit
-        """
+        """Testing list of destinations with a defined starting point and limit"""
 
         number_destinations = 20
         self._create_destinations(number_destinations)
@@ -195,9 +185,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self._check_fields(channel.json_body["destinations"])
 
     def test_next_token(self) -> None:
-        """
-        Testing that `next_token` appears at the right place
-        """
+        """Testing that `next_token` appears at the right place"""
 
         number_destinations = 20
         self._create_destinations(number_destinations)
@@ -256,9 +244,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertNotIn("next_token", channel.json_body)
 
     def test_list_all_destinations(self) -> None:
-        """
-        List all destinations.
-        """
+        """List all destinations."""
         number_destinations = 5
         self._create_destinations(number_destinations)
 
@@ -277,9 +263,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self._check_fields(channel.json_body["destinations"])
 
     def test_order_by(self) -> None:
-        """
-        Testing order list with parameter `order_by`
-        """
+        """Testing order list with parameter `order_by`"""
 
         def _order_test(
             expected_destination_list: List[str],
@@ -543,3 +527,271 @@ class FederationTestCase(unittest.HomeserverTestCase):
             self.assertIn("retry_interval", c)
             self.assertIn("failure_ts", c)
             self.assertIn("last_successful_stream_ordering", c)
+
+
+class DestinationMembershipTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastore()
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.dest = "sub0.example.com"
+        self.url = f"/_synapse/admin/v1/federation/destinations/{self.dest}/rooms"
+
+        # Record that we successfully contacted a destination in the DB.
+        self.get_success(
+            self.store.set_destination_retry_timings(self.dest, None, 0, 0)
+        )
+
+    def test_requester_is_no_admin(self) -> None:
+        """If the user is not a server admin, an error 403 is returned."""
+
+        self.register_user("user", "pass", admin=False)
+        other_user_tok = self.login("user", "pass")
+
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=other_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_invalid_parameter(self) -> None:
+        """If parameters are invalid, an error is returned."""
+
+        # negative limit
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=-5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # negative from
+        channel = self.make_request(
+            "GET",
+            self.url + "?from=-5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # invalid search order
+        channel = self.make_request(
+            "GET",
+            self.url + "?dir=bar",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # invalid destination
+        channel = self.make_request(
+            "GET",
+            "/_synapse/admin/v1/federation/destinations/%s/rooms" % ("invalid",),
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_limit(self) -> None:
+        """Testing list of destinations with limit"""
+
+        number_rooms = 5
+        self._create_destination_rooms(number_rooms)
+
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=3",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_rooms)
+        self.assertEqual(len(channel.json_body["rooms"]), 3)
+        self.assertEqual(channel.json_body["next_token"], "3")
+        self._check_fields(channel.json_body["rooms"])
+
+    def test_from(self) -> None:
+        """Testing list of rooms with a defined starting point (from)"""
+
+        number_rooms = 10
+        self._create_destination_rooms(number_rooms)
+
+        channel = self.make_request(
+            "GET",
+            self.url + "?from=5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_rooms)
+        self.assertEqual(len(channel.json_body["rooms"]), 5)
+        self.assertNotIn("next_token", channel.json_body)
+        self._check_fields(channel.json_body["rooms"])
+
+    def test_limit_and_from(self) -> None:
+        """Testing list of rooms with a defined starting point and limit"""
+
+        number_rooms = 10
+        self._create_destination_rooms(number_rooms)
+
+        channel = self.make_request(
+            "GET",
+            self.url + "?from=3&limit=5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_rooms)
+        self.assertEqual(channel.json_body["next_token"], "8")
+        self.assertEqual(len(channel.json_body["rooms"]), 5)
+        self._check_fields(channel.json_body["rooms"])
+
+    def test_order_direction(self) -> None:
+        """Testing order list with parameter `dir`"""
+        number_rooms = 4
+        self._create_destination_rooms(number_rooms)
+
+        # get list in forward direction
+        channel_asc = self.make_request(
+            "GET",
+            self.url + "?dir=f",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel_asc.code, msg=channel_asc.json_body)
+        self.assertEqual(channel_asc.json_body["total"], number_rooms)
+        self.assertEqual(number_rooms, len(channel_asc.json_body["rooms"]))
+        self._check_fields(channel_asc.json_body["rooms"])
+
+        # get list in backward direction
+        channel_desc = self.make_request(
+            "GET",
+            self.url + "?dir=b",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel_desc.code, msg=channel_desc.json_body)
+        self.assertEqual(channel_desc.json_body["total"], number_rooms)
+        self.assertEqual(number_rooms, len(channel_desc.json_body["rooms"]))
+        self._check_fields(channel_desc.json_body["rooms"])
+
+        # test that both lists have different directions
+        for i in range(0, number_rooms):
+            self.assertEqual(
+                channel_asc.json_body["rooms"][i]["room_id"],
+                channel_desc.json_body["rooms"][number_rooms - 1 - i]["room_id"],
+            )
+
+    def test_next_token(self) -> None:
+        """Testing that `next_token` appears at the right place"""
+
+        number_rooms = 5
+        self._create_destination_rooms(number_rooms)
+
+        #  `next_token` does not appear
+        # Number of results is the number of entries
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_rooms)
+        self.assertEqual(len(channel.json_body["rooms"]), number_rooms)
+        self.assertNotIn("next_token", channel.json_body)
+
+        #  `next_token` does not appear
+        # Number of max results is larger than the number of entries
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=6",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_rooms)
+        self.assertEqual(len(channel.json_body["rooms"]), number_rooms)
+        self.assertNotIn("next_token", channel.json_body)
+
+        #  `next_token` does appear
+        # Number of max results is smaller than the number of entries
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=4",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_rooms)
+        self.assertEqual(len(channel.json_body["rooms"]), 4)
+        self.assertEqual(channel.json_body["next_token"], "4")
+
+        # Check
+        # Set `from` to value of `next_token` for request remaining entries
+        #  `next_token` does not appear
+        channel = self.make_request(
+            "GET",
+            self.url + "?from=4",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_rooms)
+        self.assertEqual(len(channel.json_body["rooms"]), 1)
+        self.assertNotIn("next_token", channel.json_body)
+
+    def test_destination_rooms(self) -> None:
+        """Testing that request the list of rooms is successfully."""
+        number_rooms = 3
+        self._create_destination_rooms(number_rooms)
+
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_rooms)
+        self.assertEqual(number_rooms, len(channel.json_body["rooms"]))
+        self._check_fields(channel.json_body["rooms"])
+
+    def _create_destination_rooms(self, number_rooms: int) -> None:
+        """Create a number rooms for destination
+
+        Args:
+            number_rooms: Number of rooms to be created
+        """
+        for _ in range(0, number_rooms):
+            room_id = self.helper.create_room_as(
+                self.admin_user, tok=self.admin_user_tok
+            )
+            self.get_success(
+                self.store.store_destination_rooms_entries((self.dest,), room_id, 1234)
+            )
+
+    def _check_fields(self, content: List[JsonDict]) -> None:
+        """Checks that the expected room attributes are present in content
+
+        Args:
+            content: List that is checked for content
+        """
+        for c in content:
+            self.assertIn("room_id", c)
+            self.assertIn("stream_ordering", c)