summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15008.misc1
-rw-r--r--mypy.ini1
-rw-r--r--synapse/rest/media/v1/media_storage.py7
-rw-r--r--tests/rest/media/v1/test_media_storage.py49
4 files changed, 35 insertions, 23 deletions
diff --git a/changelog.d/15008.misc b/changelog.d/15008.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/15008.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/mypy.ini b/mypy.ini
index 0efafb26b6..4598002c4a 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -33,7 +33,6 @@ exclude = (?x)
    |synapse/storage/schema/
 
    |tests/module_api/test_api.py
-   |tests/rest/media/v1/test_media_storage.py
    |tests/server.py
    )$
 
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index a5c3de192f..db25848744 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -46,10 +46,9 @@ from ._base import FileInfo, Responder
 from .filepath import MediaFilePaths
 
 if TYPE_CHECKING:
+    from synapse.rest.media.v1.storage_provider import StorageProvider
     from synapse.server import HomeServer
 
-    from .storage_provider import StorageProviderWrapper
-
 logger = logging.getLogger(__name__)
 
 
@@ -68,7 +67,7 @@ class MediaStorage:
         hs: "HomeServer",
         local_media_directory: str,
         filepaths: MediaFilePaths,
-        storage_providers: Sequence["StorageProviderWrapper"],
+        storage_providers: Sequence["StorageProvider"],
     ):
         self.hs = hs
         self.reactor = hs.get_reactor()
@@ -360,7 +359,7 @@ class ReadableFileWrapper:
     clock: Clock
     path: str
 
-    async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None:
+    async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None:
         """Reads the file in chunks and calls the callback with each chunk."""
 
         with open(self.path, "rb") as file:
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index d18fc13c21..17a3b06a8e 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -16,7 +16,7 @@ import shutil
 import tempfile
 from binascii import unhexlify
 from io import BytesIO
-from typing import Any, BinaryIO, Dict, List, Optional, Union
+from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
 from unittest.mock import Mock
 from urllib import parse
 
@@ -32,6 +32,7 @@ from twisted.test.proto_helpers import MemoryReactor
 from synapse.api.errors import Codes
 from synapse.events import EventBase
 from synapse.events.spamcheck import load_legacy_spam_checkers
+from synapse.http.types import QueryParams
 from synapse.logging.context import make_deferred_yieldable
 from synapse.module_api import ModuleApi
 from synapse.rest import admin
@@ -41,7 +42,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths
 from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
 from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
 from synapse.server import HomeServer
-from synapse.types import RoomAlias
+from synapse.types import JsonDict, RoomAlias
 from synapse.util import Clock
 
 from tests import unittest
@@ -201,36 +202,46 @@ class _TestImage:
     ],
 )
 class MediaRepoTests(unittest.HomeserverTestCase):
-
+    test_image: ClassVar[_TestImage]
     hijack_auth = True
     user_id = "@test:user"
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
-        self.fetches = []
+        self.fetches: List[
+            Tuple[
+                "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]",
+                str,
+                str,
+                Optional[QueryParams],
+            ]
+        ] = []
 
         def get_file(
             destination: str,
             path: str,
             output_stream: BinaryIO,
-            args: Optional[Dict[str, Union[str, List[str]]]] = None,
+            args: Optional[QueryParams] = None,
+            retry_on_dns_fail: bool = True,
             max_size: Optional[int] = None,
-        ) -> Deferred:
-            """
-            Returns tuple[int,dict,str,int] of file length, response headers,
-            absolute URI, and response code.
-            """
+            ignore_backoff: bool = False,
+        ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
+            """A mock for MatrixFederationHttpClient.get_file."""
 
-            def write_to(r):
+            def write_to(
+                r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
+            ) -> Tuple[int, Dict[bytes, List[bytes]]]:
                 data, response = r
                 output_stream.write(data)
                 return response
 
-            d = Deferred()
-            d.addCallback(write_to)
+            d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
             self.fetches.append((d, destination, path, args))
-            return make_deferred_yieldable(d)
+            # Note that this callback changes the value held by d.
+            d_after_callback = d.addCallback(write_to)
+            return make_deferred_yieldable(d_after_callback)
 
+        # Mock out the homeserver's MatrixFederationHttpClient
         client = Mock()
         client.get_file = get_file
 
@@ -461,6 +472,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         # Synapse should regenerate missing thumbnails.
         origin, media_id = self.media_id.split("/")
         info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
+        assert info is not None
         file_id = info["filesystem_id"]
 
         thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
@@ -581,7 +593,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
                         "thumbnail_method": method,
                         "thumbnail_type": self.test_image.content_type,
                         "thumbnail_length": 256,
-                        "filesystem_id": f"thumbnail1{self.test_image.extension}",
+                        "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}",
                     },
                     {
                         "thumbnail_width": 32,
@@ -589,10 +601,10 @@ class MediaRepoTests(unittest.HomeserverTestCase):
                         "thumbnail_method": method,
                         "thumbnail_type": self.test_image.content_type,
                         "thumbnail_length": 256,
-                        "filesystem_id": f"thumbnail2{self.test_image.extension}",
+                        "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}",
                     },
                 ],
-                file_id=f"image{self.test_image.extension}",
+                file_id=f"image{self.test_image.extension.decode()}",
                 url_cache=None,
                 server_name=None,
             )
@@ -637,6 +649,7 @@ class TestSpamCheckerLegacy:
         self.config = config
         self.api = api
 
+    @staticmethod
     def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
         return config
 
@@ -748,7 +761,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
 
     async def check_media_file_for_spam(
         self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
-    ) -> Union[Codes, Literal["NOT_SPAM"]]:
+    ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]:
         buf = BytesIO()
         await file_wrapper.write_chunks_to(buf.write)