diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index cb1c6fbb80..2b9b56da95 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -369,6 +369,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
time_now_ms=self.clock.time_msec(),
upload_name=None,
filesystem_id="xyz",
+ sha256="abcdefg12345",
)
)
diff --git a/tests/media/test_media_retention.py b/tests/media/test_media_retention.py
index 417d17ebd2..d8f4f57c8c 100644
--- a/tests/media/test_media_retention.py
+++ b/tests/media/test_media_retention.py
@@ -31,6 +31,9 @@ from synapse.rest.client import login, register, room
from synapse.server import HomeServer
from synapse.types import UserID
from synapse.util import Clock
+from synapse.util.stringutils import (
+ random_string,
+)
from tests import unittest
from tests.unittest import override_config
@@ -65,7 +68,6 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
# quarantined media) into both the local store and the remote cache, plus
# one additional local media that is marked as protected from quarantine.
media_repository = hs.get_media_repository()
- test_media_content = b"example string"
def _create_media_and_set_attributes(
last_accessed_ms: Optional[int],
@@ -73,12 +75,14 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
is_protected: Optional[bool] = False,
) -> MXCUri:
# "Upload" some media to the local media store
+ # Use random content so that each upload has distinct bytes.
+ random_content = bytes(random_string(24), "utf-8")
mxc_uri: MXCUri = self.get_success(
media_repository.create_content(
media_type="text/plain",
upload_name=None,
- content=io.BytesIO(test_media_content),
- content_length=len(test_media_content),
+ content=io.BytesIO(random_content),
+ content_length=len(random_content),
auth_user=UserID.from_string(test_user_id),
)
)
@@ -129,6 +133,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
time_now_ms=clock.time_msec(),
upload_name="testfile.txt",
filesystem_id="abcdefg12345",
+ sha256=random_string(24),
)
)
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index c2e0e592d7..35e16a99ba 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -42,6 +42,7 @@ from twisted.web.resource import Resource
from synapse.api.errors import Codes, HttpResponseException
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
+from synapse.http.client import ByteWriteable
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
from synapse.media._base import FileInfo, ThumbnailInfo
@@ -59,7 +60,7 @@ from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel
-from tests.test_utils import SMALL_CMYK_JPEG, SMALL_PNG
+from tests.test_utils import SMALL_CMYK_JPEG, SMALL_PNG, SMALL_PNG_SHA256
from tests.unittest import override_config
from tests.utils import default_config
@@ -1257,3 +1258,107 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
)
assert channel.code == 502
assert channel.json_body["errcode"] == "M_TOO_LARGE"
+
+
+def read_body(
+    response: IResponse, stream: ByteWriteable, max_size: Optional[int]
+) -> Deferred:
+    """Fake body reader: write SMALL_PNG to `stream` and report its length."""
+    stream.write(SMALL_PNG)
+    # `defer.succeed` hands back a Deferred that has already fired.
+    return defer.succeed(len(SMALL_PNG))
+
+
+class MediaHashesTestCase(unittest.HomeserverTestCase):
+    """Tests that media is stored together with the SHA256 hash of its content.
+
+    Covers media uploaded by a local user as well as media downloaded from a
+    remote server over federation.
+    """
+
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        media.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        """Register a user and grab the main store and federation HTTP client."""
+        self.user = self.register_user("user", "pass")
+        self.tok = self.login("user", "pass")
+        self.store = hs.get_datastores().main
+        self.client = hs.get_federation_http_client()
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        resources = super().create_resource_dict()
+        # The federated test requests /_matrix/media/..., so mount that resource.
+        resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+        return resources
+
+    def _upload_small_png(self) -> Any:
+        """Upload SMALL_PNG as the test user and fetch its row from the store.
+
+        Returns:
+            Whatever `get_local_media` yields for the new media ID.
+        """
+        response = self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
+        mxc = response.get("content_uri")
+        assert mxc
+        # The media ID is the final path segment of `mxc://<server>/<media_id>`,
+        # so split on it instead of assuming a fixed-length server name.
+        media_id = mxc.rsplit("/", 1)[-1]
+        store_media = self.get_success(self.store.get_local_media(media_id))
+        assert store_media
+        return store_media
+
+    def test_ensure_correct_sha256(self) -> None:
+        """Check that an upload is stored with the SHA256 of its content."""
+        store_media = self._upload_small_png()
+        self.assertEqual(store_media.sha256, SMALL_PNG_SHA256)
+
+    def test_ensure_multiple_correct_sha256(self) -> None:
+        """Check that two uploads of the same bytes share a hash, not an ID."""
+        store_media_a = self._upload_small_png()
+        store_media_b = self._upload_small_png()
+
+        self.assertNotEqual(store_media_a.media_id, store_media_b.media_id)
+        self.assertEqual(store_media_a.sha256, store_media_b.sha256)
+
+    @override_config(
+        {
+            "enable_authenticated_media": False,
+        }
+    )
+    # Replace the function that reads the response body with our fake.
+    @patch(
+        "synapse.http.matrixfederationclient.read_body_with_max_size",
+        read_body,
+    )
+    def test_ensure_correct_sha256_federated(self) -> None:
+        """Check that media fetched over federation is stored with its hash."""
+
+        # Mock getting a file over federation
+        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+            resp = MagicMock(spec=IResponse)
+            resp.code = 200
+            resp.length = 500
+            resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
+            resp.phrase = b"OK"
+            return resp
+
+        self.client._send_request = _send_request  # type: ignore
+
+        # Download the remote media; this should succeed and cache it locally.
+        channel = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abc",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel.code == 200
+        store_media = self.get_success(
+            self.store.get_cached_remote_media("remote.org", "abc")
+        )
+        assert store_media
+        # The cached copy must carry the hash of the bytes `read_body` wrote.
+        self.assertEqual(store_media.sha256, SMALL_PNG_SHA256)
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 5483f8f37f..fc2a6c569b 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -20,7 +20,7 @@
#
import urllib.parse
-from typing import Dict
+from typing import Dict, cast
from parameterized import parameterized
@@ -32,6 +32,7 @@ from synapse.http.server import JsonResource
from synapse.rest.admin import VersionServlet
from synapse.rest.client import login, media, room
from synapse.server import HomeServer
+from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
@@ -227,10 +228,25 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Upload some media
response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
+ response_3 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
# Extract media IDs
server_and_media_id_1 = response_1["content_uri"][6:]
server_and_media_id_2 = response_2["content_uri"][6:]
+ server_and_media_id_3 = response_3["content_uri"][6:]
+
+ # Remove the hash from the media to simulate historic media.
+ self.get_success(
+ self.hs.get_datastores().main.update_local_media(
+ media_id=server_and_media_id_3.split("/")[1],
+ media_type="image/png",
+ upload_name=None,
+ media_length=123,
+ user_id=UserID.from_string(non_admin_user),
+ # Hack to force some media to have no hash.
+ sha256=cast(str, None),
+ )
+ )
# Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
@@ -244,12 +260,13 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self.pump(1.0)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
- channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
+ channel.json_body, {"num_quarantined": 3}, "Expected 3 quarantined items"
)
# Attempt to access each piece of media
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
+ self._ensure_quarantined(admin_user_tok, server_and_media_id_3)
def test_cannot_quarantine_safe_media(self) -> None:
self.register_user("user_admin", "pass", admin=True)
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 19c244cfcf..da0e9749aa 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -35,7 +35,7 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
-from tests.test_utils import SMALL_PNG
+from tests.test_utils import SMALL_CMYK_JPEG, SMALL_PNG
from tests.unittest import override_config
VALID_TIMESTAMP = 1609459200000 # 2021-01-01 in milliseconds
@@ -598,23 +598,27 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
class QuarantineMediaByIDTestCase(_AdminMediaTests):
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.server_name = hs.hostname
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
+ def upload_media_and_return_media_id(self, data: bytes) -> str:
# Upload some media into the room
response = self.helper.upload_media(
- SMALL_PNG,
+ data,
tok=self.admin_user_tok,
expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
- self.media_id = server_and_media_id.split("/")[1]
+ return server_and_media_id.split("/")[1]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.server_name = hs.hostname
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+ self.media_id = self.upload_media_and_return_media_id(SMALL_PNG)
+ self.media_id_2 = self.upload_media_and_return_media_id(SMALL_PNG)
+ self.media_id_3 = self.upload_media_and_return_media_id(SMALL_PNG)
+ self.media_id_other = self.upload_media_and_return_media_id(SMALL_CMYK_JPEG)
self.url = "/_synapse/admin/v1/media/%s/%s/%s"
@parameterized.expand(["quarantine", "unquarantine"])
@@ -686,6 +690,52 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
assert media_info is not None
self.assertFalse(media_info.quarantined_by)
+    def test_quarantine_media_match_hash(self) -> None:
+        """
+        Tests that quarantining one media item also quarantines every other
+        local media item with the same hash, and that unquarantining undoes it
+        """
+        same_hash_media_ids = [self.media_id, self.media_id_2, self.media_id_3]
+
+        # Nothing has been quarantined yet.
+        media_info = self.get_success(self.store.get_local_media(self.media_id))
+        assert media_info is not None
+        self.assertFalse(media_info.quarantined_by)
+
+        # Quarantine just the first item.
+        quarantine_url = self.url % ("quarantine", self.server_name, self.media_id)
+        channel = self.make_request(
+            "POST", quarantine_url, access_token=self.admin_user_tok
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        # An empty JSON body is returned on success.
+        self.assertFalse(channel.json_body)
+
+        # Every item uploaded with the same bytes must now be quarantined.
+        for media_id in same_hash_media_ids:
+            media_info = self.get_success(self.store.get_local_media(media_id))
+            assert media_info is not None
+            self.assertTrue(media_info.quarantined_by)
+
+        # Media with different content is left alone.
+        media_info = self.get_success(self.store.get_local_media(self.media_id_other))
+        assert media_info is not None
+        self.assertFalse(media_info.quarantined_by)
+
+        # Lift the quarantine, again via the first item only.
+        unquarantine_url = self.url % ("unquarantine", self.server_name, self.media_id)
+        channel = self.make_request(
+            "POST", unquarantine_url, access_token=self.admin_user_tok
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertFalse(channel.json_body)
+
+        # Every item with the same hash must be released again.
+        for media_id in same_hash_media_ids:
+            media_info = self.get_success(self.store.get_local_media(media_id))
+            assert media_info is not None
+            self.assertFalse(media_info.quarantined_by)
+
def test_quarantine_protected_media(self) -> None:
"""
Tests that quarantining from protected media fails
diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py
index 0e3e370ee8..1ea2a5c884 100644
--- a/tests/rest/client/test_media.py
+++ b/tests/rest/client/test_media.py
@@ -137,6 +137,7 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
time_now_ms=clock.time_msec(),
upload_name="test.png",
filesystem_id=file_id,
+ sha256=file_id,
)
)
self.register_user("user", "password")
@@ -2593,6 +2594,7 @@ class AuthenticatedMediaTestCase(unittest.HomeserverTestCase):
time_now_ms=self.clock.time_msec(),
upload_name="remote_test.png",
filesystem_id=file_id,
+ sha256=file_id,
)
)
@@ -2725,6 +2727,7 @@ class AuthenticatedMediaTestCase(unittest.HomeserverTestCase):
time_now_ms=self.clock.time_msec(),
upload_name="remote_test.png",
filesystem_id=file_id,
+ sha256=file_id,
)
)
diff --git a/tests/rest/media/test_domain_blocking.py b/tests/rest/media/test_domain_blocking.py
index 49d81f4b28..26453f70dd 100644
--- a/tests/rest/media/test_domain_blocking.py
+++ b/tests/rest/media/test_domain_blocking.py
@@ -61,6 +61,7 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
time_now_ms=clock.time_msec(),
upload_name="test.png",
filesystem_id=file_id,
+ sha256=file_id,
)
)
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index e3611852b7..3e6fd03600 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -139,6 +139,8 @@ SMALL_PNG = unhexlify(
b"0000001f15c4890000000a49444154789c63000100000500010d"
b"0a2db40000000049454e44ae426082"
)
+# The SHA256 hexdigest for the above bytes.
+SMALL_PNG_SHA256 = "ebf4f635a17d10d6eb46ba680b70142419aa3220f228001a036d311a22ee9d2a"
# A small CMYK-encoded JPEG image used in some tests.
#
|