summary refs log tree commit diff
path: root/tests/replication/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/_base.py')
-rw-r--r--tests/replication/_base.py70
1 files changed, 40 insertions, 30 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 6a7174b333..46a8e2013e 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -16,7 +16,9 @@ from collections import defaultdict
 from typing import Any, Dict, List, Optional, Set, Tuple
 
 from twisted.internet.address import IPv4Address
-from twisted.internet.protocol import Protocol
+from twisted.internet.protocol import Protocol, connectionDone
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.resource import Resource
 
 from synapse.app.generic_worker import GenericWorkerServer
@@ -30,6 +32,7 @@ from synapse.replication.tcp.protocol import (
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeTransport
@@ -51,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
     if not hiredis:
         skip = "Requires hiredis"
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
         self.streamer = hs.get_replication_streamer()
@@ -92,8 +95,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
             repl_handler,
         )
 
-        self._client_transport = None
-        self._server_transport = None
+        self._client_transport: Optional[FakeTransport] = None
+        self._server_transport: Optional[FakeTransport] = None
 
     def create_resource_dict(self) -> Dict[str, Resource]:
         d = super().create_resource_dict()
@@ -107,10 +110,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         config["worker_replication_http_port"] = "8765"
         return config
 
-    def _build_replication_data_handler(self):
+    def _build_replication_data_handler(self) -> "TestReplicationDataHandler":
         return TestReplicationDataHandler(self.worker_hs)
 
-    def reconnect(self):
+    def reconnect(self) -> None:
         if self._client_transport:
             self.client.close()
 
@@ -123,7 +126,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self._server_transport = FakeTransport(self.client, self.reactor)
         self.server.makeConnection(self._server_transport)
 
-    def disconnect(self):
+    def disconnect(self) -> None:
         if self._client_transport:
             self._client_transport = None
             self.client.close()
@@ -132,7 +135,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
             self._server_transport = None
             self.server.close()
 
-    def replicate(self):
+    def replicate(self) -> None:
         """Tell the master side of replication that something has happened, and then
         wait for the replication to occur.
         """
@@ -168,7 +171,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         requests: List[SynapseRequest] = []
         real_request_factory = channel.requestFactory
 
-        def request_factory(*args, **kwargs):
+        def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest:
             request = real_request_factory(*args, **kwargs)
             requests.append(request)
             return request
@@ -202,7 +205,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
 
     def assert_request_is_get_repl_stream_updates(
         self, request: SynapseRequest, stream_name: str
-    ):
+    ) -> None:
         """Asserts that the given request is a HTTP replication request for
         fetching updates for given stream.
         """
@@ -244,7 +247,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         base["redis"] = {"enabled": True}
         return base
 
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
 
         # build a replication server
@@ -287,7 +290,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             lambda: self._handle_http_replication_attempt(self.hs, 8765),
         )
 
-    def create_test_resource(self):
+    def create_test_resource(self) -> ReplicationRestResource:
         """Overrides `HomeserverTestCase.create_test_resource`."""
         # We override this so that it automatically registers all the HTTP
         # replication servlets, without having to explicitly do that in all
@@ -301,7 +304,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         return resource
 
     def make_worker_hs(
-        self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
+        self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
     ) -> HomeServer:
         """Make a new worker HS instance, correctly connecting replcation
         stream to the master HS.
@@ -385,14 +388,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         config["worker_replication_http_port"] = "8765"
         return config
 
-    def replicate(self):
+    def replicate(self) -> None:
         """Tell the master side of replication that something has happened, and then
         wait for the replication to occur.
         """
         self.streamer.on_notifier_poke()
         self.pump()
 
-    def _handle_http_replication_attempt(self, hs, repl_port):
+    def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> None:
         """Handles a connection attempt to the given HS replication HTTP
         listener on the given port.
         """
@@ -429,7 +432,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # inside `connecTCP` before the connection has been passed back to the
         # code that requested the TCP connection.
 
-    def connect_any_redis_attempts(self):
+    def connect_any_redis_attempts(self) -> None:
         """If redis is enabled we need to deal with workers connecting to a
         redis server. We don't want to use a real Redis server so we use a
         fake one.
@@ -440,8 +443,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             self.assertEqual(host, "localhost")
             self.assertEqual(port, 6379)
 
-            client_protocol = client_factory.buildProtocol(None)
-            server_protocol = self._redis_server.buildProtocol(None)
+            client_address = IPv4Address("TCP", "127.0.0.1", 6379)
+            client_protocol = client_factory.buildProtocol(client_address)
+
+            server_address = IPv4Address("TCP", host, port)
+            server_protocol = self._redis_server.buildProtocol(server_address)
 
             client_to_server_transport = FakeTransport(
                 server_protocol, self.reactor, client_protocol
@@ -463,7 +469,9 @@ class TestReplicationDataHandler(ReplicationDataHandler):
         # list of received (stream_name, token, row) tuples
         self.received_rdata_rows: List[Tuple[str, int, Any]] = []
 
-    async def on_rdata(self, stream_name, instance_name, token, rows):
+    async def on_rdata(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ) -> None:
         await super().on_rdata(stream_name, instance_name, token, rows)
         for r in rows:
             self.received_rdata_rows.append((stream_name, token, r))
@@ -472,28 +480,30 @@ class TestReplicationDataHandler(ReplicationDataHandler):
 class FakeRedisPubSubServer:
     """A fake Redis server for pub/sub."""
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._subscribers_by_channel: Dict[
             bytes, Set["FakeRedisPubSubProtocol"]
         ] = defaultdict(set)
 
-    def add_subscriber(self, conn, channel: bytes):
+    def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
         """A connection has called SUBSCRIBE"""
         self._subscribers_by_channel[channel].add(conn)
 
-    def remove_subscriber(self, conn):
+    def remove_subscriber(self, conn: "FakeRedisPubSubProtocol") -> None:
         """A connection has lost connection"""
         for subscribers in self._subscribers_by_channel.values():
             subscribers.discard(conn)
 
-    def publish(self, conn, channel: bytes, msg) -> int:
+    def publish(
+        self, conn: "FakeRedisPubSubProtocol", channel: bytes, msg: object
+    ) -> int:
         """A connection want to publish a message to subscribers."""
         for sub in self._subscribers_by_channel[channel]:
             sub.send(["message", channel, msg])
 
         return len(self._subscribers_by_channel)
 
-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IPv4Address) -> "FakeRedisPubSubProtocol":
         return FakeRedisPubSubProtocol(self)
 
 
@@ -506,7 +516,7 @@ class FakeRedisPubSubProtocol(Protocol):
         self._server = server
         self._reader = hiredis.Reader()
 
-    def dataReceived(self, data):
+    def dataReceived(self, data: bytes) -> None:
         self._reader.feed(data)
 
         # We might get multiple messages in one packet.
@@ -523,7 +533,7 @@ class FakeRedisPubSubProtocol(Protocol):
 
             self.handle_command(msg[0], *msg[1:])
 
-    def handle_command(self, command, *args):
+    def handle_command(self, command: bytes, *args: bytes) -> None:
         """Received a Redis command from the client."""
 
         # We currently only support pub/sub.
@@ -548,9 +558,9 @@ class FakeRedisPubSubProtocol(Protocol):
             self.send("PONG")
 
         else:
-            raise Exception(f"Unknown command: {command}")
+            raise Exception(f"Unknown command: {command!r}")
 
-    def send(self, msg):
+    def send(self, msg: object) -> None:
         """Send a message back to the client."""
         assert self.transport is not None
 
@@ -559,7 +569,7 @@ class FakeRedisPubSubProtocol(Protocol):
         self.transport.write(raw)
         self.transport.flush()
 
-    def encode(self, obj):
+    def encode(self, obj: object) -> str:
         """Encode an object to its Redis format.
 
         Supports: strings/bytes, integers and list/tuples.
@@ -581,5 +591,5 @@ class FakeRedisPubSubProtocol(Protocol):
 
         raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
 
-    def connectionLost(self, reason):
+    def connectionLost(self, reason: Failure = connectionDone) -> None:
         self._server.remove_subscriber(self)