summary refs log tree commit diff
path: root/tests/replication
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-02-06 09:55:00 -0500
committerGitHub <noreply@github.com>2023-02-06 09:55:00 -0500
commit156cd88eefe7db100e5cdba48174c709975b93ca (patch)
treebf4059f81c6ba16439ef6dfa19a4e016057da20d /tests/replication
parentExpect type stubs from canonicaljson (#14992) (diff)
downloadsynapse-156cd88eefe7db100e5cdba48174c709975b93ca.tar.xz
Add missing type hints to tests.replication. (#14987)
Diffstat (limited to 'tests/replication')
-rw-r--r--tests/replication/_base.py70
-rw-r--r--tests/replication/http/test__base.py2
-rw-r--r--tests/replication/slave/storage/_base.py25
-rw-r--r--tests/replication/slave/storage/test_events.py85
-rw-r--r--tests/replication/tcp/streams/test_account_data.py4
-rw-r--r--tests/replication/tcp/streams/test_events.py18
-rw-r--r--tests/replication/tcp/streams/test_federation.py2
-rw-r--r--tests/replication/tcp/streams/test_partial_state.py2
-rw-r--r--tests/replication/tcp/streams/test_typing.py33
-rw-r--r--tests/replication/tcp/test_commands.py6
-rw-r--r--tests/replication/tcp/test_remote_server_up.py8
-rw-r--r--tests/replication/test_auth.py14
-rw-r--r--tests/replication/test_client_reader_shard.py4
-rw-r--r--tests/replication/test_federation_ack.py12
-rw-r--r--tests/replication/test_federation_sender_shard.py10
-rw-r--r--tests/replication/test_module_cache_invalidation.py2
-rw-r--r--tests/replication/test_multi_media_repo.py16
-rw-r--r--tests/replication/test_pusher_shard.py11
-rw-r--r--tests/replication/test_sharded_event_persister.py14
19 files changed, 189 insertions, 149 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 6a7174b333..46a8e2013e 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -16,7 +16,9 @@ from collections import defaultdict
 from typing import Any, Dict, List, Optional, Set, Tuple
 
 from twisted.internet.address import IPv4Address
-from twisted.internet.protocol import Protocol
+from twisted.internet.protocol import Protocol, connectionDone
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.resource import Resource
 
 from synapse.app.generic_worker import GenericWorkerServer
@@ -30,6 +32,7 @@ from synapse.replication.tcp.protocol import (
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeTransport
@@ -51,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
     if not hiredis:
         skip = "Requires hiredis"
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
         self.streamer = hs.get_replication_streamer()
@@ -92,8 +95,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
             repl_handler,
         )
 
-        self._client_transport = None
-        self._server_transport = None
+        self._client_transport: Optional[FakeTransport] = None
+        self._server_transport: Optional[FakeTransport] = None
 
     def create_resource_dict(self) -> Dict[str, Resource]:
         d = super().create_resource_dict()
@@ -107,10 +110,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         config["worker_replication_http_port"] = "8765"
         return config
 
-    def _build_replication_data_handler(self):
+    def _build_replication_data_handler(self) -> "TestReplicationDataHandler":
         return TestReplicationDataHandler(self.worker_hs)
 
-    def reconnect(self):
+    def reconnect(self) -> None:
         if self._client_transport:
             self.client.close()
 
@@ -123,7 +126,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self._server_transport = FakeTransport(self.client, self.reactor)
         self.server.makeConnection(self._server_transport)
 
-    def disconnect(self):
+    def disconnect(self) -> None:
         if self._client_transport:
             self._client_transport = None
             self.client.close()
@@ -132,7 +135,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
             self._server_transport = None
             self.server.close()
 
-    def replicate(self):
+    def replicate(self) -> None:
         """Tell the master side of replication that something has happened, and then
         wait for the replication to occur.
         """
@@ -168,7 +171,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         requests: List[SynapseRequest] = []
         real_request_factory = channel.requestFactory
 
-        def request_factory(*args, **kwargs):
+        def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest:
             request = real_request_factory(*args, **kwargs)
             requests.append(request)
             return request
@@ -202,7 +205,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
 
     def assert_request_is_get_repl_stream_updates(
         self, request: SynapseRequest, stream_name: str
-    ):
+    ) -> None:
         """Asserts that the given request is a HTTP replication request for
         fetching updates for given stream.
         """
@@ -244,7 +247,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         base["redis"] = {"enabled": True}
         return base
 
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
 
         # build a replication server
@@ -287,7 +290,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             lambda: self._handle_http_replication_attempt(self.hs, 8765),
         )
 
-    def create_test_resource(self):
+    def create_test_resource(self) -> ReplicationRestResource:
         """Overrides `HomeserverTestCase.create_test_resource`."""
         # We override this so that it automatically registers all the HTTP
         # replication servlets, without having to explicitly do that in all
@@ -301,7 +304,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         return resource
 
     def make_worker_hs(
-        self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
+        self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
     ) -> HomeServer:
         """Make a new worker HS instance, correctly connecting replcation
         stream to the master HS.
@@ -385,14 +388,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         config["worker_replication_http_port"] = "8765"
         return config
 
-    def replicate(self):
+    def replicate(self) -> None:
         """Tell the master side of replication that something has happened, and then
         wait for the replication to occur.
         """
         self.streamer.on_notifier_poke()
         self.pump()
 
-    def _handle_http_replication_attempt(self, hs, repl_port):
+    def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> None:
         """Handles a connection attempt to the given HS replication HTTP
         listener on the given port.
         """
@@ -429,7 +432,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # inside `connecTCP` before the connection has been passed back to the
         # code that requested the TCP connection.
 
-    def connect_any_redis_attempts(self):
+    def connect_any_redis_attempts(self) -> None:
         """If redis is enabled we need to deal with workers connecting to a
         redis server. We don't want to use a real Redis server so we use a
         fake one.
@@ -440,8 +443,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             self.assertEqual(host, "localhost")
             self.assertEqual(port, 6379)
 
-            client_protocol = client_factory.buildProtocol(None)
-            server_protocol = self._redis_server.buildProtocol(None)
+            client_address = IPv4Address("TCP", "127.0.0.1", 6379)
+            client_protocol = client_factory.buildProtocol(client_address)
+
+            server_address = IPv4Address("TCP", host, port)
+            server_protocol = self._redis_server.buildProtocol(server_address)
 
             client_to_server_transport = FakeTransport(
                 server_protocol, self.reactor, client_protocol
@@ -463,7 +469,9 @@ class TestReplicationDataHandler(ReplicationDataHandler):
         # list of received (stream_name, token, row) tuples
         self.received_rdata_rows: List[Tuple[str, int, Any]] = []
 
-    async def on_rdata(self, stream_name, instance_name, token, rows):
+    async def on_rdata(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ) -> None:
         await super().on_rdata(stream_name, instance_name, token, rows)
         for r in rows:
             self.received_rdata_rows.append((stream_name, token, r))
@@ -472,28 +480,30 @@ class TestReplicationDataHandler(ReplicationDataHandler):
 class FakeRedisPubSubServer:
     """A fake Redis server for pub/sub."""
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._subscribers_by_channel: Dict[
             bytes, Set["FakeRedisPubSubProtocol"]
         ] = defaultdict(set)
 
-    def add_subscriber(self, conn, channel: bytes):
+    def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
         """A connection has called SUBSCRIBE"""
         self._subscribers_by_channel[channel].add(conn)
 
-    def remove_subscriber(self, conn):
+    def remove_subscriber(self, conn: "FakeRedisPubSubProtocol") -> None:
         """A connection has lost connection"""
         for subscribers in self._subscribers_by_channel.values():
             subscribers.discard(conn)
 
-    def publish(self, conn, channel: bytes, msg) -> int:
+    def publish(
+        self, conn: "FakeRedisPubSubProtocol", channel: bytes, msg: object
+    ) -> int:
         """A connection want to publish a message to subscribers."""
         for sub in self._subscribers_by_channel[channel]:
             sub.send(["message", channel, msg])
 
         return len(self._subscribers_by_channel)
 
-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IPv4Address) -> "FakeRedisPubSubProtocol":
         return FakeRedisPubSubProtocol(self)
 
 
@@ -506,7 +516,7 @@ class FakeRedisPubSubProtocol(Protocol):
         self._server = server
         self._reader = hiredis.Reader()
 
-    def dataReceived(self, data):
+    def dataReceived(self, data: bytes) -> None:
         self._reader.feed(data)
 
         # We might get multiple messages in one packet.
@@ -523,7 +533,7 @@ class FakeRedisPubSubProtocol(Protocol):
 
             self.handle_command(msg[0], *msg[1:])
 
-    def handle_command(self, command, *args):
+    def handle_command(self, command: bytes, *args: bytes) -> None:
         """Received a Redis command from the client."""
 
         # We currently only support pub/sub.
@@ -548,9 +558,9 @@ class FakeRedisPubSubProtocol(Protocol):
             self.send("PONG")
 
         else:
-            raise Exception(f"Unknown command: {command}")
+            raise Exception(f"Unknown command: {command!r}")
 
-    def send(self, msg):
+    def send(self, msg: object) -> None:
         """Send a message back to the client."""
         assert self.transport is not None
 
@@ -559,7 +569,7 @@ class FakeRedisPubSubProtocol(Protocol):
         self.transport.write(raw)
         self.transport.flush()
 
-    def encode(self, obj):
+    def encode(self, obj: object) -> str:
         """Encode an object to its Redis format.
 
         Supports: strings/bytes, integers and list/tuples.
@@ -581,5 +591,5 @@ class FakeRedisPubSubProtocol(Protocol):
 
         raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
 
-    def connectionLost(self, reason):
+    def connectionLost(self, reason: Failure = connectionDone) -> None:
         self._server.remove_subscriber(self)
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
index e03d9b4cc0..9be11ab802 100644
--- a/tests/replication/http/test__base.py
+++ b/tests/replication/http/test__base.py
@@ -74,7 +74,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
 class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
     """Tests for `ReplicationEndpoint` cancellation."""
 
-    def create_test_resource(self):
+    def create_test_resource(self) -> JsonResource:
         """Overrides `HomeserverTestCase.create_test_resource`."""
         resource = JsonResource(self.hs)
 
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index c5705256e6..4c9b494344 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -13,35 +13,42 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any, Iterable, Optional
 from unittest.mock import Mock
 
-from tests.replication._base import BaseStreamTestCase
+from twisted.test.proto_helpers import MemoryReactor
 
+from synapse.server import HomeServer
+from synapse.util import Clock
 
-class BaseSlavedStoreTestCase(BaseStreamTestCase):
-    def make_homeserver(self, reactor, clock):
+from tests.replication._base import BaseStreamTestCase
 
-        hs = self.setup_test_homeserver(federation_client=Mock())
 
-        return hs
+class BaseSlavedStoreTestCase(BaseStreamTestCase):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        return self.setup_test_homeserver(federation_client=Mock())
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         super().prepare(reactor, clock, hs)
 
         self.reconnect()
 
         self.master_store = hs.get_datastores().main
         self.slaved_store = self.worker_hs.get_datastores().main
-        self._storage_controllers = hs.get_storage_controllers()
+        persistence = hs.get_storage_controllers().persistence
+        assert persistence is not None
+        self.persistance = persistence
 
-    def replicate(self):
+    def replicate(self) -> None:
         """Tell the master side of replication that something has happened, and then
         wait for the replication to occur.
         """
         self.streamer.on_notifier_poke()
         self.pump(0.1)
 
-    def check(self, method, args, expected_result=None):
+    def check(
+        self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None
+    ) -> None:
         master_result = self.get_success(getattr(self.master_store, method)(*args))
         slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
         if expected_result is not None:
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index dce71f7334..ddca9d696c 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -12,15 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Iterable, Optional
+from typing import Any, Callable, Iterable, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 from parameterized import parameterized
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import ReceiptTypes
 from synapse.api.room_versions import RoomVersions
-from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
+from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
+from synapse.events.snapshot import EventContext
 from synapse.handlers.room import RoomEventSource
+from synapse.server import HomeServer
 from synapse.storage.databases.main.event_push_actions import (
     NotifCounts,
     RoomNotifCounts,
@@ -28,6 +32,7 @@ from synapse.storage.databases.main.event_push_actions import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
 from synapse.types import PersistedEventPosition
+from synapse.util import Clock
 
 from tests.server import FakeTransport
 
@@ -41,19 +46,19 @@ ROOM_ID = "!room:test"
 logger = logging.getLogger(__name__)
 
 
-def dict_equals(self, other):
+def dict_equals(self: EventBase, other: EventBase) -> bool:
     me = encode_canonical_json(self.get_pdu_json())
     them = encode_canonical_json(other.get_pdu_json())
     return me == them
 
 
-def patch__eq__(cls):
+def patch__eq__(cls: object) -> Callable[[], None]:
     eq = getattr(cls, "__eq__", None)
-    cls.__eq__ = dict_equals
+    cls.__eq__ = dict_equals  # type: ignore[assignment]
 
-    def unpatch():
+    def unpatch() -> None:
         if eq is not None:
-            cls.__eq__ = eq
+            cls.__eq__ = eq  # type: ignore[assignment]
 
     return unpatch
 
@@ -62,14 +67,14 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
 
     STORE_TYPE = EventsWorkerStore
 
-    def setUp(self):
+    def setUp(self) -> None:
         # Patch up the equality operator for events so that we can check
         # whether lists of events match using assertEqual
-        self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
-        return super().setUp()
+        self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)]
+        super().setUp()
 
-    def prepare(self, *args, **kwargs):
-        super().prepare(*args, **kwargs)
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        super().prepare(reactor, clock, hs)
 
         self.get_success(
             self.master_store.store_room(
@@ -80,10 +85,10 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
             )
         )
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         [unpatch() for unpatch in self.unpatches]
 
-    def test_get_latest_event_ids_in_room(self):
+    def test_get_latest_event_ids_in_room(self) -> None:
         create = self.persist(type="m.room.create", key="", creator=USER_ID)
         self.replicate()
         self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
@@ -97,7 +102,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
         self.replicate()
         self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
 
-    def test_redactions(self):
+    def test_redactions(self) -> None:
         self.persist(type="m.room.create", key="", creator=USER_ID)
         self.persist(type="m.room.member", key=USER_ID, membership="join")
 
@@ -117,7 +122,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
         )
         self.check("get_event", [msg.event_id], redacted)
 
-    def test_backfilled_redactions(self):
+    def test_backfilled_redactions(self) -> None:
         self.persist(type="m.room.create", key="", creator=USER_ID)
         self.persist(type="m.room.member", key=USER_ID, membership="join")
 
@@ -139,7 +144,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
         )
         self.check("get_event", [msg.event_id], redacted)
 
-    def test_invites(self):
+    def test_invites(self) -> None:
         self.persist(type="m.room.create", key="", creator=USER_ID)
         self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
         event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
@@ -163,7 +168,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
         )
 
     @parameterized.expand([(True,), (False,)])
-    def test_push_actions_for_user(self, send_receipt: bool):
+    def test_push_actions_for_user(self, send_receipt: bool) -> None:
         self.persist(type="m.room.create", key="", creator=USER_ID)
         self.persist(type="m.room.member", key=USER_ID, membership="join")
         self.persist(
@@ -219,7 +224,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
             ),
         )
 
-    def test_get_rooms_for_user_with_stream_ordering(self):
+    def test_get_rooms_for_user_with_stream_ordering(self) -> None:
         """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
         by rows in the events stream
         """
@@ -243,7 +248,9 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
             {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
         )
 
-    def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
+    def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(
+        self,
+    ) -> None:
         """Check that current_state invalidation happens correctly with multiple events
         in the persistence batch.
 
@@ -283,11 +290,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
             type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
         )
         msg, msgctx = self.build_event()
-        self.get_success(
-            self._storage_controllers.persistence.persist_events(
-                [(j2, j2ctx), (msg, msgctx)]
-            )
-        )
+        self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)]))
         self.replicate()
         assert j2.internal_metadata.stream_ordering is not None
 
@@ -339,7 +342,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
 
     event_id = 0
 
-    def persist(self, backfill=False, **kwargs) -> FrozenEvent:
+    def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase:
         """
         Returns:
             The event that was persisted.
@@ -348,32 +351,28 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
 
         if backfill:
             self.get_success(
-                self._storage_controllers.persistence.persist_events(
-                    [(event, context)], backfilled=True
-                )
+                self.persistance.persist_events([(event, context)], backfilled=True)
             )
         else:
-            self.get_success(
-                self._storage_controllers.persistence.persist_event(event, context)
-            )
+            self.get_success(self.persistance.persist_event(event, context))
 
         return event
 
     def build_event(
         self,
-        sender=USER_ID,
-        room_id=ROOM_ID,
-        type="m.room.message",
-        key=None,
+        sender: str = USER_ID,
+        room_id: str = ROOM_ID,
+        type: str = "m.room.message",
+        key: Optional[str] = None,
         internal: Optional[dict] = None,
-        depth=None,
-        prev_events: Optional[list] = None,
-        auth_events: Optional[list] = None,
-        prev_state: Optional[list] = None,
-        redacts=None,
+        depth: Optional[int] = None,
+        prev_events: Optional[List[Tuple[str, dict]]] = None,
+        auth_events: Optional[List[str]] = None,
+        prev_state: Optional[List[str]] = None,
+        redacts: Optional[str] = None,
         push_actions: Iterable = frozenset(),
-        **content,
-    ):
+        **content: object,
+    ) -> Tuple[EventBase, EventContext]:
         prev_events = prev_events or []
         auth_events = auth_events or []
         prev_state = prev_state or []
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
index 50fbff5f32..01df1be047 100644
--- a/tests/replication/tcp/streams/test_account_data.py
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -21,7 +21,7 @@ from tests.replication._base import BaseStreamTestCase
 
 
 class AccountDataStreamTestCase(BaseStreamTestCase):
-    def test_update_function_room_account_data_limit(self):
+    def test_update_function_room_account_data_limit(self) -> None:
         """Test replication with many room account data updates"""
         store = self.hs.get_datastores().main
 
@@ -67,7 +67,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
-    def test_update_function_global_account_data_limit(self):
+    def test_update_function_global_account_data_limit(self) -> None:
         """Test replication with many global account data updates"""
         store = self.hs.get_datastores().main
 
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 641a94133b..043dbe76af 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+from typing import Any, List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
@@ -25,6 +27,8 @@ from synapse.replication.tcp.streams.events import (
 )
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.replication._base import BaseStreamTestCase
 from tests.test_utils.event_injection import inject_event, inject_member_event
@@ -37,7 +41,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         room.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         super().prepare(reactor, clock, hs)
         self.user_id = self.register_user("u1", "pass")
         self.user_tok = self.login("u1", "pass")
@@ -47,7 +51,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.room_id = self.helper.create_room_as(tok=self.user_tok)
         self.test_handler.received_rdata_rows.clear()
 
-    def test_update_function_event_row_limit(self):
+    def test_update_function_event_row_limit(self) -> None:
         """Test replication with many non-state events
 
         Checks that all events are correctly replicated when there are lots of
@@ -102,7 +106,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
-    def test_update_function_huge_state_change(self):
+    def test_update_function_huge_state_change(self) -> None:
         """Test replication with many state events
 
         Ensures that all events are correctly replicated when there are lots of
@@ -256,7 +260,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
             # "None" indicates the state has been deleted
             self.assertIsNone(sr.event_id)
 
-    def test_update_function_state_row_limit(self):
+    def test_update_function_state_row_limit(self) -> None:
         """Test replication with many state events over several stream ids."""
 
         # we want to generate lots of state changes, but for this test, we want to
@@ -376,7 +380,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
-    def test_backwards_stream_id(self):
+    def test_backwards_stream_id(self) -> None:
         """
         Test that RDATA that comes after the current position should be discarded.
         """
@@ -437,7 +441,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
     event_count = 0
 
     def _inject_test_event(
-        self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
+        self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs: Any
     ) -> EventBase:
         if sender is None:
             sender = self.user_id
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
index bcb82c9c80..cdbdfaf057 100644
--- a/tests/replication/tcp/streams/test_federation.py
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -26,7 +26,7 @@ class FederationStreamTestCase(BaseStreamTestCase):
         config["federation_sender_instances"] = ["federation_sender1"]
         return config
 
-    def test_catchup(self):
+    def test_catchup(self) -> None:
         """Basic test of catchup on reconnect
 
         Makes sure that updates sent while we are offline are received later.
diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py
index 2c10eab4db..38b5020ce0 100644
--- a/tests/replication/tcp/streams/test_partial_state.py
+++ b/tests/replication/tcp/streams/test_partial_state.py
@@ -23,7 +23,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
     hijack_auth = True
     user_id = "@bob:test"
 
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
         self.store = self.hs.get_datastores().main
 
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 9a229dd23f..68de5d1cc2 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -27,10 +27,11 @@ ROOM_ID_2 = "!foo:blue"
 
 
 class TypingStreamTestCase(BaseStreamTestCase):
-    def _build_replication_data_handler(self):
-        return Mock(wraps=super()._build_replication_data_handler())
+    def _build_replication_data_handler(self) -> Mock:
+        self.mock_handler = Mock(wraps=super()._build_replication_data_handler())
+        return self.mock_handler
 
-    def test_typing(self):
+    def test_typing(self) -> None:
         typing = self.hs.get_typing_handler()
 
         self.reconnect()
@@ -43,8 +44,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
         request = self.handle_http_replication_attempt()
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row: TypingStream.TypingStreamRow = rdata_rows[0]
@@ -54,11 +55,11 @@ class TypingStreamTestCase(BaseStreamTestCase):
         # Now let's disconnect and insert some data.
         self.disconnect()
 
-        self.test_handler.on_rdata.reset_mock()
+        self.mock_handler.on_rdata.reset_mock()
 
         typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
 
-        self.test_handler.on_rdata.assert_not_called()
+        self.mock_handler.on_rdata.assert_not_called()
 
         self.reconnect()
         self.pump(0.1)
@@ -71,15 +72,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
         assert request.args is not None
         self.assertEqual(int(request.args[b"from_token"][0]), token)
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]
         self.assertEqual(ROOM_ID, row.room_id)
         self.assertEqual([], row.user_ids)
 
-    def test_reset(self):
+    def test_reset(self) -> None:
         """
         Test what happens when a typing stream resets.
 
@@ -98,8 +99,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
         request = self.handle_http_replication_attempt()
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row: TypingStream.TypingStreamRow = rdata_rows[0]
@@ -134,15 +135,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
         # Reset the test code.
-        self.test_handler.on_rdata.reset_mock()
-        self.test_handler.on_rdata.assert_not_called()
+        self.mock_handler.on_rdata.reset_mock()
+        self.mock_handler.on_rdata.assert_not_called()
 
         # Push additional data.
         typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
         self.reactor.advance(0)
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]
diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py
index cca7ebb719..5d6b72b16d 100644
--- a/tests/replication/tcp/test_commands.py
+++ b/tests/replication/tcp/test_commands.py
@@ -21,12 +21,12 @@ from tests.unittest import TestCase
 
 
 class ParseCommandTestCase(TestCase):
-    def test_parse_one_word_command(self):
+    def test_parse_one_word_command(self) -> None:
         line = "REPLICATE"
         cmd = parse_command_from_line(line)
         self.assertIsInstance(cmd, ReplicateCommand)
 
-    def test_parse_rdata(self):
+    def test_parse_rdata(self) -> None:
         line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
         cmd = parse_command_from_line(line)
         assert isinstance(cmd, RdataCommand)
@@ -34,7 +34,7 @@ class ParseCommandTestCase(TestCase):
         self.assertEqual(cmd.instance_name, "master")
         self.assertEqual(cmd.token, 6287863)
 
-    def test_parse_rdata_batch(self):
+    def test_parse_rdata_batch(self) -> None:
         line = 'RDATA presence master batch ["@foo:example.com", "online"]'
         cmd = parse_command_from_line(line)
         assert isinstance(cmd, RdataCommand)
diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py
index 545f11acd1..b75fc05fd5 100644
--- a/tests/replication/tcp/test_remote_server_up.py
+++ b/tests/replication/tcp/test_remote_server_up.py
@@ -16,15 +16,17 @@ from typing import Tuple
 
 from twisted.internet.address import IPv4Address
 from twisted.internet.interfaces import IProtocol
-from twisted.test.proto_helpers import StringTransport
+from twisted.test.proto_helpers import MemoryReactor, StringTransport
 
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
 
 
 class RemoteServerUpTestCase(HomeserverTestCase):
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.factory = ReplicationStreamProtocolFactory(hs)
 
     def _make_client(self) -> Tuple[IProtocol, StringTransport]:
@@ -40,7 +42,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
 
         return proto, transport
 
-    def test_relay(self):
+    def test_relay(self) -> None:
         """Test that Synapse will relay REMOTE_SERVER_UP commands to all
         other connections, but not the one that sent it.
         """
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
index 5d7a89e0c7..98602371e4 100644
--- a/tests/replication/test_auth.py
+++ b/tests/replication/test_auth.py
@@ -13,7 +13,11 @@
 # limitations under the License.
 import logging
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.rest.client import register
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.server import FakeChannel, make_request
@@ -27,7 +31,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
 
     servlets = [register.register_servlets]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
         # This isn't a real configuration option but is used to provide the main
         # homeserver and worker homeserver different options.
@@ -77,7 +81,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
             {"auth": {"session": session, "type": "m.login.dummy"}},
         )
 
-    def test_no_auth(self):
+    def test_no_auth(self) -> None:
         """With no authentication the request should finish."""
         channel = self._test_register()
         self.assertEqual(channel.code, 200)
@@ -86,7 +90,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
         self.assertEqual(channel.json_body["user_id"], "@user:test")
 
     @override_config({"main_replication_secret": "my-secret"})
-    def test_missing_auth(self):
+    def test_missing_auth(self) -> None:
         """If the main process expects a secret that is not provided, an error results."""
         channel = self._test_register()
         self.assertEqual(channel.code, 500)
@@ -97,13 +101,13 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
             "worker_replication_secret": "wrong-secret",
         }
     )
-    def test_unauthorized(self):
+    def test_unauthorized(self) -> None:
         """If the main process receives the wrong secret, an error results."""
         channel = self._test_register()
         self.assertEqual(channel.code, 500)
 
     @override_config({"worker_replication_secret": "my-secret"})
-    def test_authorized(self):
+    def test_authorized(self) -> None:
         """The request should finish when the worker provides the authentication header."""
         channel = self._test_register()
         self.assertEqual(channel.code, 200)
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index eb5b376534..eca5033761 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -33,7 +33,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         config["worker_replication_http_port"] = "8765"
         return config
 
-    def test_register_single_worker(self):
+    def test_register_single_worker(self) -> None:
         """Test that registration works when using a single generic worker."""
         worker_hs = self.make_worker_hs("synapse.app.generic_worker")
         site = self._hs_to_site[worker_hs]
@@ -63,7 +63,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         # We're given a registered user.
         self.assertEqual(channel_2.json_body["user_id"], "@user:test")
 
-    def test_register_multi_worker(self):
+    def test_register_multi_worker(self) -> None:
         """Test that registration works when using multiple generic workers."""
         worker_hs_1 = self.make_worker_hs("synapse.app.generic_worker")
         worker_hs_2 = self.make_worker_hs("synapse.app.generic_worker")
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 63b1dd40b5..12668b34c5 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -14,10 +14,14 @@
 
 from unittest import mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.app.generic_worker import GenericWorkerServer
 from synapse.replication.tcp.commands import FederationAckCommand
 from synapse.replication.tcp.protocol import IReplicationConnection
 from synapse.replication.tcp.streams.federation import FederationStream
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
 
@@ -30,12 +34,10 @@ class FederationAckTestCase(HomeserverTestCase):
         config["federation_sender_instances"] = ["federation_sender1"]
         return config
 
-    def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
-
-        return hs
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        return self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
 
-    def test_federation_ack_sent(self):
+    def test_federation_ack_sent(self) -> None:
         """A FEDERATION_ACK should be sent back after each RDATA federation
 
         This test checks that the federation sender is correctly sending back
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index c28073b8f7..89380e25b5 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -40,7 +40,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         room.register_servlets,
     ]
 
-    def test_send_event_single_sender(self):
+    def test_send_event_single_sender(self) -> None:
         """Test that using a single federation sender worker correctly sends a
         new event.
         """
@@ -71,7 +71,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         self.assertEqual(mock_client.put_json.call_args[0][0], "other_server")
         self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus"))
 
-    def test_send_event_sharded(self):
+    def test_send_event_sharded(self) -> None:
         """Test that using two federation sender workers correctly sends
         new events.
         """
@@ -138,7 +138,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         self.assertTrue(sent_on_1)
         self.assertTrue(sent_on_2)
 
-    def test_send_typing_sharded(self):
+    def test_send_typing_sharded(self) -> None:
         """Test that using two federation sender workers correctly sends
         new typing EDUs.
         """
@@ -215,7 +215,9 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         self.assertTrue(sent_on_1)
         self.assertTrue(sent_on_2)
 
-    def create_room_with_remote_server(self, user, token, remote_server="other_server"):
+    def create_room_with_remote_server(
+        self, user: str, token: str, remote_server: str = "other_server"
+    ) -> str:
         room = self.helper.create_room_as(user, tok=token)
         store = self.hs.get_datastores().main
         federation = self.hs.get_federation_event_handler()
diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py
index b93cae67d3..9c4fbda71b 100644
--- a/tests/replication/test_module_cache_invalidation.py
+++ b/tests/replication/test_module_cache_invalidation.py
@@ -39,7 +39,7 @@ class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
         synapse.rest.admin.register_servlets,
     ]
 
-    def test_module_cache_full_invalidation(self):
+    def test_module_cache_full_invalidation(self) -> None:
         main_cache = TestCache()
         self.hs.get_module_api().register_cached_function(main_cache.cached_function)
 
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 96cdf2c45b..1527b4a82d 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -18,12 +18,14 @@ from typing import Optional, Tuple
 from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
 from twisted.internet.protocol import Factory
 from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
+from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.http import HTTPChannel
 from twisted.web.server import Request
 
 from synapse.rest import admin
 from synapse.rest.client import login
 from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
 from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -43,13 +45,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         login.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user_id = self.register_user("user", "pass")
         self.access_token = self.login("user", "pass")
 
         self.reactor.lookups["example.com"] = "1.2.3.4"
 
-    def default_config(self):
+    def default_config(self) -> dict:
         conf = super().default_config()
         conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
         return conf
@@ -122,7 +124,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
 
         return channel, request
 
-    def test_basic(self):
+    def test_basic(self) -> None:
         """Test basic fetching of remote media from a single worker."""
         hs1 = self.make_worker_hs("synapse.app.generic_worker")
 
@@ -138,7 +140,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.result["body"], b"Hello!")
 
-    def test_download_simple_file_race(self):
+    def test_download_simple_file_race(self) -> None:
         """Test that fetching remote media from two different processes at the
         same time works.
         """
@@ -177,7 +179,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         # We expect only one new file to have been persisted.
         self.assertEqual(start_count + 1, self._count_remote_media())
 
-    def test_download_image_race(self):
+    def test_download_image_race(self) -> None:
         """Test that fetching remote *images* from two different processes at
         the same time works.
 
@@ -229,7 +231,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         return sum(len(files) for _, _, files in os.walk(path))
 
 
-def get_connection_factory():
+def get_connection_factory() -> TestServerTLSConnectionFactory:
     # this needs to happen once, but not until we are ready to run the first test
     global test_server_connection_factory
     if test_server_connection_factory is None:
@@ -263,6 +265,6 @@ def _build_test_server(
     return server_tls_factory.buildProtocol(None)
 
 
-def _log_request(request):
+def _log_request(request: Request) -> None:
     """Implements Factory.log, which is expected by Request.finish"""
     logger.info("Completed request %s", request)
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index ca18ad6553..9345cfbeb2 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -15,9 +15,12 @@ import logging
 from unittest.mock import Mock
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 
@@ -33,12 +36,12 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         login.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # Register a user who sends a message that we'll get notified about
         self.other_user_id = self.register_user("otheruser", "pass")
         self.other_access_token = self.login("otheruser", "pass")
 
-    def _create_pusher_and_send_msg(self, localpart):
+    def _create_pusher_and_send_msg(self, localpart: str) -> str:
         # Create a user that will get push notifications
         user_id = self.register_user(localpart, "pass")
         access_token = self.login(localpart, "pass")
@@ -79,7 +82,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
 
         return event_id
 
-    def test_send_push_single_worker(self):
+    def test_send_push_single_worker(self) -> None:
         """Test that registration works when using a pusher worker."""
         http_client_mock = Mock(spec_set=["post_json_get_json"])
         http_client_mock.post_json_get_json.side_effect = (
@@ -109,7 +112,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
             ],
         )
 
-    def test_send_push_multiple_workers(self):
+    def test_send_push_multiple_workers(self) -> None:
         """Test that registration works when using sharded pusher workers."""
         http_client_mock1 = Mock(spec_set=["post_json_get_json"])
         http_client_mock1.post_json_get_json.side_effect = (
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 541d390286..7f9cc67e73 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -14,9 +14,13 @@
 import logging
 from unittest.mock import patch
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.rest import admin
 from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.util import Clock
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.server import make_request
@@ -34,7 +38,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         sync.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # Register a user who sends a message that we'll get notified about
         self.other_user_id = self.register_user("otheruser", "pass")
         self.other_access_token = self.login("otheruser", "pass")
@@ -42,7 +46,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         self.room_creator = self.hs.get_room_creation_handler()
         self.store = hs.get_datastores().main
 
-    def default_config(self):
+    def default_config(self) -> dict:
         conf = super().default_config()
         conf["stream_writers"] = {"events": ["worker1", "worker2"]}
         conf["instance_map"] = {
@@ -51,7 +55,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         }
         return conf
 
-    def _create_room(self, room_id: str, user_id: str, tok: str):
+    def _create_room(self, room_id: str, user_id: str, tok: str) -> None:
         """Create a room with given room_id"""
 
         # We control the room ID generation by patching out the
@@ -62,7 +66,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
             mock.side_effect = lambda: room_id
             self.helper.create_room_as(user_id, tok=tok)
 
-    def test_basic(self):
+    def test_basic(self) -> None:
         """Simple test to ensure that multiple rooms can be created and joined,
         and that different rooms get handled by different instances.
         """
@@ -112,7 +116,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         self.assertTrue(persisted_on_1)
         self.assertTrue(persisted_on_2)
 
-    def test_vector_clock_token(self):
+    def test_vector_clock_token(self) -> None:
         """Tests that using a stream token with a vector clock component works
         correctly with basic /sync and /messages usage.
         """