diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 5483f8f37f..fc2a6c569b 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -20,7 +20,7 @@
#
import urllib.parse
-from typing import Dict
+from typing import Dict, cast
from parameterized import parameterized
@@ -32,6 +32,7 @@ from synapse.http.server import JsonResource
from synapse.rest.admin import VersionServlet
from synapse.rest.client import login, media, room
from synapse.server import HomeServer
+from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
@@ -227,10 +228,25 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Upload some media
response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
+ response_3 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
# Extract media IDs
server_and_media_id_1 = response_1["content_uri"][6:]
server_and_media_id_2 = response_2["content_uri"][6:]
+ server_and_media_id_3 = response_3["content_uri"][6:]
+
+ # Remove the hash from the media to simulate historic media.
+ self.get_success(
+ self.hs.get_datastores().main.update_local_media(
+ media_id=server_and_media_id_3.split("/")[1],
+ media_type="image/png",
+ upload_name=None,
+ media_length=123,
+ user_id=UserID.from_string(non_admin_user),
+ # Hack to force some media to have no hash.
+ sha256=cast(str, None),
+ )
+ )
# Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
@@ -244,12 +260,13 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self.pump(1.0)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
- channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
+ channel.json_body, {"num_quarantined": 3}, "Expected 3 quarantined items"
)
# Attempt to access each piece of media
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
+ self._ensure_quarantined(admin_user_tok, server_and_media_id_3)
def test_cannot_quarantine_safe_media(self) -> None:
self.register_user("user_admin", "pass", admin=True)
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 19c244cfcf..da0e9749aa 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -35,7 +35,7 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
-from tests.test_utils import SMALL_PNG
+from tests.test_utils import SMALL_CMYK_JPEG, SMALL_PNG
from tests.unittest import override_config
VALID_TIMESTAMP = 1609459200000 # 2021-01-01 in milliseconds
@@ -598,23 +598,27 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
class QuarantineMediaByIDTestCase(_AdminMediaTests):
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.server_name = hs.hostname
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
+ def upload_media_and_return_media_id(self, data: bytes) -> str:
# Upload some media into the room
response = self.helper.upload_media(
- SMALL_PNG,
+ data,
tok=self.admin_user_tok,
expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
- self.media_id = server_and_media_id.split("/")[1]
+ return server_and_media_id.split("/")[1]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.server_name = hs.hostname
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+ self.media_id = self.upload_media_and_return_media_id(SMALL_PNG)
+ self.media_id_2 = self.upload_media_and_return_media_id(SMALL_PNG)
+ self.media_id_3 = self.upload_media_and_return_media_id(SMALL_PNG)
+ self.media_id_other = self.upload_media_and_return_media_id(SMALL_CMYK_JPEG)
self.url = "/_synapse/admin/v1/media/%s/%s/%s"
@parameterized.expand(["quarantine", "unquarantine"])
@@ -686,6 +690,52 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
assert media_info is not None
self.assertFalse(media_info.quarantined_by)
+ def test_quarantine_media_match_hash(self) -> None:
+ """
+ Tests that quarantining removes all media with the same hash
+ """
+
+ media_info = self.get_success(self.store.get_local_media(self.media_id))
+ assert media_info is not None
+ self.assertFalse(media_info.quarantined_by)
+
+ # quarantining
+ channel = self.make_request(
+ "POST",
+ self.url % ("quarantine", self.server_name, self.media_id),
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertFalse(channel.json_body)
+
+ # Test that ALL similar media was quarantined.
+ for media in [self.media_id, self.media_id_2, self.media_id_3]:
+ media_info = self.get_success(self.store.get_local_media(media))
+ assert media_info is not None
+ self.assertTrue(media_info.quarantined_by)
+
+ # Test that other media was not.
+ media_info = self.get_success(self.store.get_local_media(self.media_id_other))
+ assert media_info is not None
+ self.assertFalse(media_info.quarantined_by)
+
+ # remove from quarantine
+ channel = self.make_request(
+ "POST",
+ self.url % ("unquarantine", self.server_name, self.media_id),
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertFalse(channel.json_body)
+
+ # Test that ALL similar media is now reset.
+ for media in [self.media_id, self.media_id_2, self.media_id_3]:
+ media_info = self.get_success(self.store.get_local_media(media))
+ assert media_info is not None
+ self.assertFalse(media_info.quarantined_by)
+
def test_quarantine_protected_media(self) -> None:
"""
Tests that quarantining from protected media fails
|