summary refs log tree commit diff
path: root/tests/replication/test_multi_media_repo.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/test_multi_media_repo.py')
-rw-r--r--tests/replication/test_multi_media_repo.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 96cdf2c45b..1527b4a82d 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -18,12 +18,14 @@ from typing import Optional, Tuple
 from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
 from twisted.internet.protocol import Factory
 from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
+from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.http import HTTPChannel
 from twisted.web.server import Request
 
 from synapse.rest import admin
 from synapse.rest.client import login
 from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
 from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -43,13 +45,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         login.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user_id = self.register_user("user", "pass")
         self.access_token = self.login("user", "pass")
 
         self.reactor.lookups["example.com"] = "1.2.3.4"
 
-    def default_config(self):
+    def default_config(self) -> dict:
         conf = super().default_config()
         conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
         return conf
@@ -122,7 +124,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
 
         return channel, request
 
-    def test_basic(self):
+    def test_basic(self) -> None:
         """Test basic fetching of remote media from a single worker."""
         hs1 = self.make_worker_hs("synapse.app.generic_worker")
 
@@ -138,7 +140,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.result["body"], b"Hello!")
 
-    def test_download_simple_file_race(self):
+    def test_download_simple_file_race(self) -> None:
         """Test that fetching remote media from two different processes at the
         same time works.
         """
@@ -177,7 +179,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         # We expect only one new file to have been persisted.
         self.assertEqual(start_count + 1, self._count_remote_media())
 
-    def test_download_image_race(self):
+    def test_download_image_race(self) -> None:
         """Test that fetching remote *images* from two different processes at
         the same time works.
 
@@ -229,7 +231,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         return sum(len(files) for _, _, files in os.walk(path))
 
 
-def get_connection_factory():
+def get_connection_factory() -> TestServerTLSConnectionFactory:
     # this needs to happen once, but not until we are ready to run the first test
     global test_server_connection_factory
     if test_server_connection_factory is None:
@@ -263,6 +265,6 @@ def _build_test_server(
     return server_tls_factory.buildProtocol(None)
 
 
-def _log_request(request):
+def _log_request(request: Request) -> None:
     """Implements Factory.log, which is expected by Request.finish"""
     logger.info("Completed request %s", request)