summary refs log tree commit diff
path: root/tests/rest/media
diff options
context:
space:
mode:
authorDavid Teller <D.O.Teller@gmail.com>2022-07-11 18:52:10 +0200
committerGitHub <noreply@github.com>2022-07-11 16:52:10 +0000
commit11f811470ff94dedc4232072b7f9ff099d4fcbd6 (patch)
tree982ec784ffb679557df2949e0b0060b096fd562c /tests/rest/media
parentFix to-device messages not being sent to MSC3202-enabled appservices (#13235) (diff)
downloadsynapse-11f811470ff94dedc4232072b7f9ff099d4fcbd6.tar.xz
Uniformize spam-checker API, part 5: expand other spam-checker callbacks to return `Tuple[Codes, dict]` (#13044)
Signed-off-by: David Teller <davidt@element.io>
Co-authored-by: Brendan Abolivier <babolivier@matrix.org>
Diffstat (limited to 'tests/rest/media')
-rw-r--r--tests/rest/media/v1/test_media_storage.py70
1 files changed, 67 insertions, 3 deletions
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 1c67e1ca91..79727c430f 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -23,11 +23,13 @@ from urllib import parse
 import attr
 from parameterized import parameterized, parameterized_class
 from PIL import Image as Image
+from typing_extensions import Literal
 
 from twisted.internet import defer
 from twisted.internet.defer import Deferred
 from twisted.test.proto_helpers import MemoryReactor
 
+from synapse.api.errors import Codes
 from synapse.events import EventBase
 from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.logging.context import make_deferred_yieldable
@@ -570,9 +572,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         )
 
 
-class TestSpamChecker:
+class TestSpamCheckerLegacy:
     """A spam checker module that rejects all media that includes the bytes
     `evil`.
+
+    Uses the legacy Spam-Checker API.
     """
 
     def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None:
@@ -613,7 +617,7 @@ class TestSpamChecker:
         return b"evil" in buf.getvalue()
 
 
-class SpamCheckerTestCase(unittest.HomeserverTestCase):
+class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase):
     servlets = [
         login.register_servlets,
         admin.register_servlets,
@@ -637,7 +641,8 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
             {
                 "spam_checker": [
                     {
-                        "module": TestSpamChecker.__module__ + ".TestSpamChecker",
+                        "module": TestSpamCheckerLegacy.__module__
+                        + ".TestSpamCheckerLegacy",
                         "config": {},
                     }
                 ]
@@ -662,3 +667,62 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
         self.helper.upload_media(
             self.upload_resource, data, tok=self.tok, expect_code=400
         )
+
+
+EVIL_DATA = b"Some evil data"
+EVIL_DATA_EXPERIMENT = b"Some evil data to trigger the experimental tuple API"
+
+
+class SpamCheckerTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        login.register_servlets,
+        admin.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.user = self.register_user("user", "pass")
+        self.tok = self.login("user", "pass")
+
+        # Allow for uploading and downloading to/from the media repo
+        self.media_repo = hs.get_media_repository_resource()
+        self.download_resource = self.media_repo.children[b"download"]
+        self.upload_resource = self.media_repo.children[b"upload"]
+
+        hs.get_module_api().register_spam_checker_callbacks(
+            check_media_file_for_spam=self.check_media_file_for_spam
+        )
+
+    async def check_media_file_for_spam(
+        self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
+    ) -> Union[Codes, Literal["NOT_SPAM"]]:
+        buf = BytesIO()
+        await file_wrapper.write_chunks_to(buf.write)
+
+        if buf.getvalue() == EVIL_DATA:
+            return Codes.FORBIDDEN
+        elif buf.getvalue() == EVIL_DATA_EXPERIMENT:
+            return (Codes.FORBIDDEN, {})
+        else:
+            return "NOT_SPAM"
+
+    def test_upload_innocent(self) -> None:
+        """Attempt to upload some innocent data that should be allowed."""
+        self.helper.upload_media(
+            self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
+        )
+
+    def test_upload_ban(self) -> None:
+        """Attempt to upload some data that includes bytes "evil", which should
+        get rejected by the spam checker.
+        """
+
+        self.helper.upload_media(
+            self.upload_resource, EVIL_DATA, tok=self.tok, expect_code=400
+        )
+
+        self.helper.upload_media(
+            self.upload_resource,
+            EVIL_DATA_EXPERIMENT,
+            tok=self.tok,
+            expect_code=400,
+        )