diff options
Diffstat (limited to 'tests/replication/test_multi_media_repo.py')
-rw-r--r-- | tests/replication/test_multi_media_repo.py | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 96cdf2c45b..1527b4a82d 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -18,12 +18,14 @@ from typing import Optional, Tuple from twisted.internet.interfaces import IOpenSSLServerConnectionCreator from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol +from twisted.test.proto_helpers import MemoryReactor from twisted.web.http import HTTPChannel from twisted.web.server import Request from synapse.rest import admin from synapse.rest.client import login from synapse.server import HomeServer +from synapse.util import Clock from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.replication._base import BaseMultiWorkerStreamTestCase @@ -43,13 +45,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("user", "pass") self.access_token = self.login("user", "pass") self.reactor.lookups["example.com"] = "1.2.3.4" - def default_config(self): + def default_config(self) -> dict: conf = super().default_config() conf["federation_custom_ca_list"] = [get_test_ca_cert_file()] return conf @@ -122,7 +124,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): return channel, request - def test_basic(self): + def test_basic(self) -> None: """Test basic fetching of remote media from a single worker.""" hs1 = self.make_worker_hs("synapse.app.generic_worker") @@ -138,7 +140,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.result["body"], b"Hello!") - def test_download_simple_file_race(self): + def test_download_simple_file_race(self) -> None: """Test that fetching remote media from two different processes at the same time works. """ @@ -177,7 +179,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): # We expect only one new file to have been persisted. self.assertEqual(start_count + 1, self._count_remote_media()) - def test_download_image_race(self): + def test_download_image_race(self) -> None: """Test that fetching remote *images* from two different processes at the same time works. @@ -229,7 +231,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): return sum(len(files) for _, _, files in os.walk(path)) -def get_connection_factory(): +def get_connection_factory() -> TestServerTLSConnectionFactory: # this needs to happen once, but not until we are ready to run the first test global test_server_connection_factory if test_server_connection_factory is None: @@ -263,6 +265,6 @@ def _build_test_server( return server_tls_factory.buildProtocol(None) -def _log_request(request): +def _log_request(request: Request) -> None: """Implements Factory.log, which is expected by Request.finish""" logger.info("Completed request %s", request) |