diff options
-rw-r--r-- | changelog.d/15008.misc | 1 | ||||
-rw-r--r-- | mypy.ini | 1 | ||||
-rw-r--r-- | synapse/rest/media/v1/media_storage.py | 7 | ||||
-rw-r--r-- | tests/rest/media/v1/test_media_storage.py | 49 |
4 files changed, 35 insertions, 23 deletions
diff --git a/changelog.d/15008.misc b/changelog.d/15008.misc new file mode 100644 index 0000000000..93ceaeafc9 --- /dev/null +++ b/changelog.d/15008.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/mypy.ini b/mypy.ini index 0efafb26b6..4598002c4a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -33,7 +33,6 @@ exclude = (?x) |synapse/storage/schema/ |tests/module_api/test_api.py - |tests/rest/media/v1/test_media_storage.py |tests/server.py )$ diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index a5c3de192f..db25848744 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -46,10 +46,9 @@ from ._base import FileInfo, Responder from .filepath import MediaFilePaths if TYPE_CHECKING: + from synapse.rest.media.v1.storage_provider import StorageProvider from synapse.server import HomeServer - from .storage_provider import StorageProviderWrapper - logger = logging.getLogger(__name__) @@ -68,7 +67,7 @@ class MediaStorage: hs: "HomeServer", local_media_directory: str, filepaths: MediaFilePaths, - storage_providers: Sequence["StorageProviderWrapper"], + storage_providers: Sequence["StorageProvider"], ): self.hs = hs self.reactor = hs.get_reactor() @@ -360,7 +359,7 @@ class ReadableFileWrapper: clock: Clock path: str - async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None: + async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None: """Reads the file in chunks and calls the callback with each chunk.""" with open(self.path, "rb") as file: diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index d18fc13c21..17a3b06a8e 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -16,7 +16,7 @@ import shutil import tempfile from binascii import unhexlify from io import BytesIO -from typing import Any, BinaryIO, Dict, List, Optional, Union +from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union from unittest.mock import Mock from urllib import parse @@ -32,6 +32,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes from synapse.events import EventBase from synapse.events.spamcheck import load_legacy_spam_checkers +from synapse.http.types import QueryParams from synapse.logging.context import make_deferred_yieldable from synapse.module_api import ModuleApi from synapse.rest import admin @@ -41,7 +42,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend from synapse.server import HomeServer -from synapse.types import RoomAlias +from synapse.types import JsonDict, RoomAlias from synapse.util import Clock from tests import unittest @@ -201,36 +202,46 @@ class _TestImage: ], ) class MediaRepoTests(unittest.HomeserverTestCase): - + test_image: ClassVar[_TestImage] hijack_auth = True user_id = "@test:user" def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.fetches = [] + self.fetches: List[ + Tuple[ + "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]", + str, + str, + Optional[QueryParams], + ] + ] = [] def get_file( destination: str, path: str, output_stream: BinaryIO, - args: Optional[Dict[str, Union[str, List[str]]]] = None, + args: Optional[QueryParams] = None, + retry_on_dns_fail: bool = True, max_size: Optional[int] = None, - ) -> Deferred: - """ - Returns tuple[int,dict,str,int] of file length, response headers, - absolute URI, and response code. - """ + ignore_backoff: bool = False, + ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]": + """A mock for MatrixFederationHttpClient.get_file.""" - def write_to(r): + def write_to( + r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]] + ) -> Tuple[int, Dict[bytes, List[bytes]]]: data, response = r output_stream.write(data) return response - d = Deferred() - d.addCallback(write_to) + d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred() self.fetches.append((d, destination, path, args)) - return make_deferred_yieldable(d) + # Note that this callback changes the value held by d. + d_after_callback = d.addCallback(write_to) + return make_deferred_yieldable(d_after_callback) + # Mock out the homeserver's MatrixFederationHttpClient client = Mock() client.get_file = get_file @@ -461,6 +472,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): # Synapse should regenerate missing thumbnails. origin, media_id = self.media_id.split("/") info = self.get_success(self.store.get_cached_remote_media(origin, media_id)) + assert info is not None file_id = info["filesystem_id"] thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir( @@ -581,7 +593,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): "thumbnail_method": method, "thumbnail_type": self.test_image.content_type, "thumbnail_length": 256, - "filesystem_id": f"thumbnail1{self.test_image.extension}", + "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}", }, { "thumbnail_width": 32, @@ -589,10 +601,10 @@ class MediaRepoTests(unittest.HomeserverTestCase): "thumbnail_method": method, "thumbnail_type": self.test_image.content_type, "thumbnail_length": 256, - "filesystem_id": f"thumbnail2{self.test_image.extension}", + "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}", }, ], - file_id=f"image{self.test_image.extension}", + file_id=f"image{self.test_image.extension.decode()}", url_cache=None, server_name=None, ) @@ -637,6 +649,7 @@ class TestSpamCheckerLegacy: self.config = config self.api = api + @staticmethod def parse_config(config: Dict[str, Any]) -> Dict[str, Any]: return config @@ -748,7 +761,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): async def check_media_file_for_spam( self, file_wrapper: ReadableFileWrapper, file_info: FileInfo - ) -> Union[Codes, Literal["NOT_SPAM"]]: + ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]: buf = BytesIO() await file_wrapper.write_chunks_to(buf.write) |