diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index d18fc13c21..17a3b06a8e 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -16,7 +16,7 @@ import shutil
import tempfile
from binascii import unhexlify
from io import BytesIO
-from typing import Any, BinaryIO, Dict, List, Optional, Union
+from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
from unittest.mock import Mock
from urllib import parse
@@ -32,6 +32,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes
from synapse.events import EventBase
from synapse.events.spamcheck import load_legacy_spam_checkers
+from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
from synapse.module_api import ModuleApi
from synapse.rest import admin
@@ -41,7 +42,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from synapse.server import HomeServer
-from synapse.types import RoomAlias
+from synapse.types import JsonDict, RoomAlias
from synapse.util import Clock
from tests import unittest
@@ -201,36 +202,46 @@ class _TestImage:
],
)
class MediaRepoTests(unittest.HomeserverTestCase):
-
+ test_image: ClassVar[_TestImage]
hijack_auth = True
user_id = "@test:user"
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.fetches = []
+ self.fetches: List[
+ Tuple[
+ "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]",
+ str,
+ str,
+ Optional[QueryParams],
+ ]
+ ] = []
def get_file(
destination: str,
path: str,
output_stream: BinaryIO,
- args: Optional[Dict[str, Union[str, List[str]]]] = None,
+ args: Optional[QueryParams] = None,
+ retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
- ) -> Deferred:
- """
- Returns tuple[int,dict,str,int] of file length, response headers,
- absolute URI, and response code.
- """
+ ignore_backoff: bool = False,
+ ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
+ """A mock for MatrixFederationHttpClient.get_file."""
- def write_to(r):
+ def write_to(
+ r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
+ ) -> Tuple[int, Dict[bytes, List[bytes]]]:
data, response = r
output_stream.write(data)
return response
- d = Deferred()
- d.addCallback(write_to)
+ d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
self.fetches.append((d, destination, path, args))
- return make_deferred_yieldable(d)
+ # Note that this callback changes the value held by d.
+ d_after_callback = d.addCallback(write_to)
+ return make_deferred_yieldable(d_after_callback)
+ # Mock out the homeserver's MatrixFederationHttpClient
client = Mock()
client.get_file = get_file
@@ -461,6 +472,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
# Synapse should regenerate missing thumbnails.
origin, media_id = self.media_id.split("/")
info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
+ assert info is not None
file_id = info["filesystem_id"]
thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
@@ -581,7 +593,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"thumbnail_method": method,
"thumbnail_type": self.test_image.content_type,
"thumbnail_length": 256,
- "filesystem_id": f"thumbnail1{self.test_image.extension}",
+ "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}",
},
{
"thumbnail_width": 32,
@@ -589,10 +601,10 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"thumbnail_method": method,
"thumbnail_type": self.test_image.content_type,
"thumbnail_length": 256,
- "filesystem_id": f"thumbnail2{self.test_image.extension}",
+ "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}",
},
],
- file_id=f"image{self.test_image.extension}",
+ file_id=f"image{self.test_image.extension.decode()}",
url_cache=None,
server_name=None,
)
@@ -637,6 +649,7 @@ class TestSpamCheckerLegacy:
self.config = config
self.api = api
+ @staticmethod
def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
return config
@@ -748,7 +761,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
async def check_media_file_for_spam(
self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
- ) -> Union[Codes, Literal["NOT_SPAM"]]:
+ ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]:
buf = BytesIO()
await file_wrapper.write_chunks_to(buf.write)
|