summary refs log tree commit diff
path: root/tests/rest/admin/test_federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/admin/test_federation.py')
-rw-r--r--tests/rest/admin/test_federation.py357
1 files changed, 328 insertions, 29 deletions
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index b70350b6f1..71068d16cd 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -20,7 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.api.errors import Codes
-from synapse.rest.client import login
+from synapse.rest.client import login, room
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
@@ -43,20 +43,22 @@ class FederationTestCase(unittest.HomeserverTestCase):
 
     @parameterized.expand(
         [
-            ("/_synapse/admin/v1/federation/destinations",),
-            ("/_synapse/admin/v1/federation/destinations/dummy",),
+            ("GET", "/_synapse/admin/v1/federation/destinations"),
+            ("GET", "/_synapse/admin/v1/federation/destinations/dummy"),
+            (
+                "POST",
+                "/_synapse/admin/v1/federation/destinations/dummy/reset_connection",
+            ),
         ]
     )
-    def test_requester_is_no_admin(self, url: str) -> None:
-        """
-        If the user is not a server admin, an error 403 is returned.
-        """
+    def test_requester_is_no_admin(self, method: str, url: str) -> None:
+        """If the user is not a server admin, an error 403 is returned."""
 
         self.register_user("user", "pass", admin=False)
         other_user_tok = self.login("user", "pass")
 
         channel = self.make_request(
-            "GET",
+            method,
             url,
             content={},
             access_token=other_user_tok,
@@ -66,9 +68,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_invalid_parameter(self) -> None:
-        """
-        If parameters are invalid, an error is returned.
-        """
+        """If parameters are invalid, an error is returned."""
 
         # negative limit
         channel = self.make_request(
@@ -120,10 +120,18 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
+        # invalid destination
+        channel = self.make_request(
+            "POST",
+            self.url + "/dummy/reset_connection",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
     def test_limit(self) -> None:
-        """
-        Testing list of destinations with limit
-        """
+        """Testing list of destinations with limit"""
 
         number_destinations = 20
         self._create_destinations(number_destinations)
@@ -141,9 +149,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self._check_fields(channel.json_body["destinations"])
 
     def test_from(self) -> None:
-        """
-        Testing list of destinations with a defined starting point (from)
-        """
+        """Testing list of destinations with a defined starting point (from)"""
 
         number_destinations = 20
         self._create_destinations(number_destinations)
@@ -161,9 +167,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self._check_fields(channel.json_body["destinations"])
 
     def test_limit_and_from(self) -> None:
-        """
-        Testing list of destinations with a defined starting point and limit
-        """
+        """Testing list of destinations with a defined starting point and limit"""
 
         number_destinations = 20
         self._create_destinations(number_destinations)
@@ -181,9 +185,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self._check_fields(channel.json_body["destinations"])
 
     def test_next_token(self) -> None:
-        """
-        Testing that `next_token` appears at the right place
-        """
+        """Testing that `next_token` appears at the right place"""
 
         number_destinations = 20
         self._create_destinations(number_destinations)
@@ -242,9 +244,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertNotIn("next_token", channel.json_body)
 
     def test_list_all_destinations(self) -> None:
-        """
-        List all destinations.
-        """
+        """List all destinations."""
         number_destinations = 5
         self._create_destinations(number_destinations)
 
@@ -263,9 +263,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self._check_fields(channel.json_body["destinations"])
 
     def test_order_by(self) -> None:
-        """
-        Testing order list with parameter `order_by`
-        """
+        """Testing order list with parameter `order_by`"""
 
         def _order_test(
             expected_destination_list: List[str],
@@ -444,6 +442,39 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertIsNone(channel.json_body["failure_ts"])
         self.assertIsNone(channel.json_body["last_successful_stream_ordering"])
 
+    def test_destination_reset_connection(self) -> None:
+        """Reset timeouts and wake up destination."""
+        self._create_destination("sub0.example.com", 100, 100, 100)
+
+        channel = self.make_request(
+            "POST",
+            self.url + "/sub0.example.com/reset_connection",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+
+        retry_timings = self.get_success(
+            self.store.get_destination_retry_timings("sub0.example.com")
+        )
+        self.assertIsNone(retry_timings)
+
+    def test_destination_reset_connection_not_required(self) -> None:
+        """Try to reset timeouts of a destination with no timeouts and get an error."""
+        self._create_destination("sub0.example.com", None, 0, 0)
+
+        channel = self.make_request(
+            "POST",
+            self.url + "/sub0.example.com/reset_connection",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(
+            "The retry timing does not need to be reset for this destination.",
+            channel.json_body["error"],
+        )
+
     def _create_destination(
         self,
         destination: str,
@@ -496,3 +527,271 @@ class FederationTestCase(unittest.HomeserverTestCase):
             self.assertIn("retry_interval", c)
             self.assertIn("failure_ts", c)
             self.assertIn("last_successful_stream_ordering", c)
+
+
+class DestinationMembershipTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastore()
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.dest = "sub0.example.com"
+        self.url = f"/_synapse/admin/v1/federation/destinations/{self.dest}/rooms"
+
+        # Record that we successfully contacted a destination in the DB.
+        self.get_success(
+            self.store.set_destination_retry_timings(self.dest, None, 0, 0)
+        )
+
+    def test_requester_is_no_admin(self) -> None:
+        """If the user is not a server admin, an error 403 is returned."""
+
+        self.register_user("user", "pass", admin=False)
+        other_user_tok = self.login("user", "pass")
+
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=other_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_invalid_parameter(self) -> None:
+        """If parameters are invalid, an error is returned."""
+
+        # negative limit
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=-5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # negative from
+        channel = self.make_request(
+            "GET",
+            self.url + "?from=-5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # invalid search order
+        channel = self.make_request(
+            "GET",
+            self.url + "?dir=bar",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # invalid destination
+        channel = self.make_request(
+            "GET",
+            "/_synapse/admin/v1/federation/destinations/%s/rooms" % ("invalid",),
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_limit(self) -> None:
+        """Testing list of destinations with limit"""
+
+        number_rooms = 5
+        self._create_destination_rooms(number_rooms)
+
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=3",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_rooms)
+        self.assertEqual(len(channel.json_body["rooms"]), 3)
+        self.assertEqual(channel.json_body["next_token"], "3")
+        self._check_fields(channel.json_body["rooms"])
+
+    def test_from(self) -> None:
+        """Testing list of rooms with a defined starting point (from)"""
+
+        number_rooms = 10
+        self._create_destination_rooms(number_rooms)
+
+        channel = self.make_request(
+            "GET",
+            self.url + "?from=5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_rooms)
+        self.assertEqual(len(channel.json_body["rooms"]), 5)
+        self.assertNotIn("next_token", channel.json_body)
+        self._check_fields(channel.json_body["rooms"])
+
+    def test_limit_and_from(self) -> None:
+        """Testing list of rooms with a defined starting point and limit"""
+
+        number_rooms = 10
+        self._create_destination_rooms(number_rooms)
+
+        channel = self.make_request(
+            "GET",
+            self.url + "?from=3&limit=5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_rooms)
+        self.assertEqual(channel.json_body["next_token"], "8")
+        self.assertEqual(len(channel.json_body["rooms"]), 5)
+        self._check_fields(channel.json_body["rooms"])
+
+    def test_order_direction(self) -> None:
+        """Testing order list with parameter `dir`"""
+        number_rooms = 4
+        self._create_destination_rooms(number_rooms)
+
+        # get list in forward direction
+        channel_asc = self.make_request(
+            "GET",
+            self.url + "?dir=f",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel_asc.code, msg=channel_asc.json_body)
+        self.assertEqual(channel_asc.json_body["total"], number_rooms)
+        self.assertEqual(number_rooms, len(channel_asc.json_body["rooms"]))
+        self._check_fields(channel_asc.json_body["rooms"])
+
+        # get list in backward direction
+        channel_desc = self.make_request(
+            "GET",
+            self.url + "?dir=b",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel_desc.code, msg=channel_desc.json_body)
+        self.assertEqual(channel_desc.json_body["total"], number_rooms)
+        self.assertEqual(number_rooms, len(channel_desc.json_body["rooms"]))
+        self._check_fields(channel_desc.json_body["rooms"])
+
+        # test that both lists have different directions
+        for i in range(0, number_rooms):
+            self.assertEqual(
+                channel_asc.json_body["rooms"][i]["room_id"],
+                channel_desc.json_body["rooms"][number_rooms - 1 - i]["room_id"],
+            )
+
+    def test_next_token(self) -> None:
+        """Testing that `next_token` appears at the right place"""
+
+        number_rooms = 5
+        self._create_destination_rooms(number_rooms)
+
+        #  `next_token` does not appear
+        # Number of results is the number of entries
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_rooms)
+        self.assertEqual(len(channel.json_body["rooms"]), number_rooms)
+        self.assertNotIn("next_token", channel.json_body)
+
+        #  `next_token` does not appear
+        # Number of max results is larger than the number of entries
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=6",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_rooms)
+        self.assertEqual(len(channel.json_body["rooms"]), number_rooms)
+        self.assertNotIn("next_token", channel.json_body)
+
+        #  `next_token` does appear
+        # Number of max results is smaller than the number of entries
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=4",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_rooms)
+        self.assertEqual(len(channel.json_body["rooms"]), 4)
+        self.assertEqual(channel.json_body["next_token"], "4")
+
+        # Check
+        # Set `from` to value of `next_token` for request remaining entries
+        #  `next_token` does not appear
+        channel = self.make_request(
+            "GET",
+            self.url + "?from=4",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_rooms)
+        self.assertEqual(len(channel.json_body["rooms"]), 1)
+        self.assertNotIn("next_token", channel.json_body)
+
+    def test_destination_rooms(self) -> None:
+        """Testing that request the list of rooms is successfully."""
+        number_rooms = 3
+        self._create_destination_rooms(number_rooms)
+
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_rooms)
+        self.assertEqual(number_rooms, len(channel.json_body["rooms"]))
+        self._check_fields(channel.json_body["rooms"])
+
+    def _create_destination_rooms(self, number_rooms: int) -> None:
+        """Create a number rooms for destination
+
+        Args:
+            number_rooms: Number of rooms to be created
+        """
+        for _ in range(0, number_rooms):
+            room_id = self.helper.create_room_as(
+                self.admin_user, tok=self.admin_user_tok
+            )
+            self.get_success(
+                self.store.store_destination_rooms_entries((self.dest,), room_id, 1234)
+            )
+
+    def _check_fields(self, content: List[JsonDict]) -> None:
+        """Checks that the expected room attributes are present in content
+
+        Args:
+            content: List that is checked for content
+        """
+        for c in content:
+            self.assertIn("room_id", c)
+            self.assertIn("stream_ordering", c)