summary refs log tree commit diff
path: root/tests/rest/admin/test_admin.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/admin/test_admin.py')
-rw-r--r--tests/rest/admin/test_admin.py134
1 files changed, 40 insertions, 94 deletions
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 3adadcb46b..849d00ab4d 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -12,18 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import urllib.parse
 from http import HTTPStatus
-from unittest.mock import Mock
+from typing import List
 
-from twisted.internet.defer import Deferred
+from parameterized import parameterized
+
+from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.http.server import JsonResource
-from synapse.logging.context import make_deferred_yieldable
 from synapse.rest.admin import VersionServlet
 from synapse.rest.client import groups, login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeSite, make_request
@@ -33,12 +35,12 @@ from tests.test_utils import SMALL_PNG
 class VersionTestCase(unittest.HomeserverTestCase):
     url = "/_synapse/admin/v1/server_version"
 
-    def create_test_resource(self):
+    def create_test_resource(self) -> JsonResource:
         resource = JsonResource(self.hs)
         VersionServlet(self.hs).register(resource)
         return resource
 
-    def test_version_string(self):
+    def test_version_string(self) -> None:
         channel = self.make_request("GET", self.url, shorthand=False)
 
         self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
@@ -54,14 +56,14 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
         groups.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
         self.other_user = self.register_user("user", "pass")
         self.other_user_token = self.login("user", "pass")
 
-    def test_delete_group(self):
+    def test_delete_group(self) -> None:
         # Create a new group
         channel = self.make_request(
             "POST",
@@ -112,7 +114,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
         self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
         self.assertNotIn(group_id, self._get_groups_user_is_in(self.other_user_token))
 
-    def _check_group(self, group_id, expect_code):
+    def _check_group(self, group_id: str, expect_code: int) -> None:
         """Assert that trying to fetch the given group results in the given
         HTTP status code
         """
@@ -124,7 +126,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(expect_code, channel.code, msg=channel.json_body)
 
-    def _get_groups_user_is_in(self, access_token):
+    def _get_groups_user_is_in(self, access_token: str) -> List[str]:
         """Returns the list of groups the user is in (given their access token)"""
         channel = self.make_request("GET", b"/joined_groups", access_token=access_token)
 
@@ -143,59 +145,15 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # Allow for uploading and downloading to/from the media repo
         self.media_repo = hs.get_media_repository_resource()
         self.download_resource = self.media_repo.children[b"download"]
         self.upload_resource = self.media_repo.children[b"upload"]
 
-    def make_homeserver(self, reactor, clock):
-
-        self.fetches = []
-
-        async def get_file(destination, path, output_stream, args=None, max_size=None):
-            """
-            Returns tuple[int,dict,str,int] of file length, response headers,
-            absolute URI, and response code.
-            """
-
-            def write_to(r):
-                data, response = r
-                output_stream.write(data)
-                return response
-
-            d = Deferred()
-            d.addCallback(write_to)
-            self.fetches.append((d, destination, path, args))
-            return await make_deferred_yieldable(d)
-
-        client = Mock()
-        client.get_file = get_file
-
-        self.storage_path = self.mktemp()
-        self.media_store_path = self.mktemp()
-        os.mkdir(self.storage_path)
-        os.mkdir(self.media_store_path)
-
-        config = self.default_config()
-        config["media_store_path"] = self.media_store_path
-        config["thumbnail_requirements"] = {}
-        config["max_image_pixels"] = 2000000
-
-        provider_config = {
-            "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
-            "store_local": True,
-            "store_synchronous": False,
-            "store_remote": True,
-            "config": {"directory": self.storage_path},
-        }
-        config["media_storage_providers"] = [provider_config]
-
-        hs = self.setup_test_homeserver(config=config, federation_http_client=client)
-
-        return hs
-
-    def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
+    def _ensure_quarantined(
+        self, admin_user_tok: str, server_and_media_id: str
+    ) -> None:
         """Ensure a piece of media is quarantined when trying to access it."""
         channel = make_request(
             self.reactor,
@@ -216,12 +174,18 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
             ),
         )
 
-    def test_quarantine_media_requires_admin(self):
+    @parameterized.expand(
+        [
+            # Attempt quarantine media APIs as non-admin
+            "/_synapse/admin/v1/media/quarantine/example.org/abcde12345",
+            # And the roomID/userID endpoint
+            "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine",
+        ]
+    )
+    def test_quarantine_media_requires_admin(self, url: str) -> None:
         self.register_user("nonadmin", "pass", admin=False)
         non_admin_user_tok = self.login("nonadmin", "pass")
 
-        # Attempt quarantine media APIs as non-admin
-        url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345"
         channel = self.make_request(
             "POST",
             url.encode("ascii"),
@@ -235,22 +199,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
             msg="Expected forbidden on quarantining media as a non-admin",
         )
 
-        # And the roomID/userID endpoint
-        url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine"
-        channel = self.make_request(
-            "POST",
-            url.encode("ascii"),
-            access_token=non_admin_user_tok,
-        )
-
-        # Expect a forbidden error
-        self.assertEqual(
-            HTTPStatus.FORBIDDEN,
-            channel.code,
-            msg="Expected forbidden on quarantining media as a non-admin",
-        )
-
-    def test_quarantine_media_by_id(self):
+    def test_quarantine_media_by_id(self) -> None:
         self.register_user("id_admin", "pass", admin=True)
         admin_user_tok = self.login("id_admin", "pass")
 
@@ -295,7 +244,15 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         # Attempt to access the media
         self._ensure_quarantined(admin_user_tok, server_name_and_media_id)
 
-    def test_quarantine_all_media_in_room(self, override_url_template=None):
+    @parameterized.expand(
+        [
+            # regular API path
+            "/_synapse/admin/v1/room/%s/media/quarantine",
+            # deprecated API path
+            "/_synapse/admin/v1/quarantine_media/%s",
+        ]
+    )
+    def test_quarantine_all_media_in_room(self, url: str) -> None:
         self.register_user("room_admin", "pass", admin=True)
         admin_user_tok = self.login("room_admin", "pass")
 
@@ -333,16 +290,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
             tok=non_admin_user_tok,
         )
 
-        # Quarantine all media in the room
-        if override_url_template:
-            url = override_url_template % urllib.parse.quote(room_id)
-        else:
-            url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote(
-                room_id
-            )
         channel = self.make_request(
             "POST",
-            url,
+            url % urllib.parse.quote(room_id),
             access_token=admin_user_tok,
         )
         self.pump(1.0)
@@ -359,11 +309,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
         self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
 
-    def test_quarantine_all_media_in_room_deprecated_api_path(self):
-        # Perform the above test with the deprecated API path
-        self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s")
-
-    def test_quarantine_all_media_by_user(self):
+    def test_quarantine_all_media_by_user(self) -> None:
         self.register_user("user_admin", "pass", admin=True)
         admin_user_tok = self.login("user_admin", "pass")
 
@@ -401,7 +347,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
         self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
 
-    def test_cannot_quarantine_safe_media(self):
+    def test_cannot_quarantine_safe_media(self) -> None:
         self.register_user("user_admin", "pass", admin=True)
         admin_user_tok = self.login("user_admin", "pass")
 
@@ -475,7 +421,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -488,7 +434,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
         self.url = f"/_synapse/admin/v1/purge_history/{self.room_id}"
         self.url_status = "/_synapse/admin/v1/purge_history_status/"
 
-    def test_purge_history(self):
+    def test_purge_history(self) -> None:
         """
         Simple test of purge history API.
         Test only that is is possible to call, get status HTTPStatus.OK and purge_id.