diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 6a7174b333..46a8e2013e 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -16,7 +16,9 @@ from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple
from twisted.internet.address import IPv4Address
-from twisted.internet.protocol import Protocol
+from twisted.internet.protocol import Protocol, connectionDone
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
from synapse.app.generic_worker import GenericWorkerServer
@@ -30,6 +32,7 @@ from synapse.replication.tcp.protocol import (
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport
@@ -51,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
@@ -92,8 +95,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
repl_handler,
)
- self._client_transport = None
- self._server_transport = None
+ self._client_transport: Optional[FakeTransport] = None
+ self._server_transport: Optional[FakeTransport] = None
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
@@ -107,10 +110,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
config["worker_replication_http_port"] = "8765"
return config
- def _build_replication_data_handler(self):
+ def _build_replication_data_handler(self) -> "TestReplicationDataHandler":
return TestReplicationDataHandler(self.worker_hs)
- def reconnect(self):
+ def reconnect(self) -> None:
if self._client_transport:
self.client.close()
@@ -123,7 +126,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._server_transport = FakeTransport(self.client, self.reactor)
self.server.makeConnection(self._server_transport)
- def disconnect(self):
+ def disconnect(self) -> None:
if self._client_transport:
self._client_transport = None
self.client.close()
@@ -132,7 +135,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._server_transport = None
self.server.close()
- def replicate(self):
+ def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
@@ -168,7 +171,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
requests: List[SynapseRequest] = []
real_request_factory = channel.requestFactory
- def request_factory(*args, **kwargs):
+ def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest:
request = real_request_factory(*args, **kwargs)
requests.append(request)
return request
@@ -202,7 +205,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
- ):
+ ) -> None:
"""Asserts that the given request is a HTTP replication request for
fetching updates for given stream.
"""
@@ -244,7 +247,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
base["redis"] = {"enabled": True}
return base
- def setUp(self):
+ def setUp(self) -> None:
super().setUp()
# build a replication server
@@ -287,7 +290,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
lambda: self._handle_http_replication_attempt(self.hs, 8765),
)
- def create_test_resource(self):
+ def create_test_resource(self) -> ReplicationRestResource:
"""Overrides `HomeserverTestCase.create_test_resource`."""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all
@@ -301,7 +304,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
return resource
def make_worker_hs(
- self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
+ self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
) -> HomeServer:
"""Make a new worker HS instance, correctly connecting replcation
stream to the master HS.
@@ -385,14 +388,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config["worker_replication_http_port"] = "8765"
return config
- def replicate(self):
+ def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
self.streamer.on_notifier_poke()
self.pump()
- def _handle_http_replication_attempt(self, hs, repl_port):
+ def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> None:
"""Handles a connection attempt to the given HS replication HTTP
listener on the given port.
"""
@@ -429,7 +432,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection.
- def connect_any_redis_attempts(self):
+ def connect_any_redis_attempts(self) -> None:
"""If redis is enabled we need to deal with workers connecting to a
redis server. We don't want to use a real Redis server so we use a
fake one.
@@ -440,8 +443,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)
- client_protocol = client_factory.buildProtocol(None)
- server_protocol = self._redis_server.buildProtocol(None)
+ client_address = IPv4Address("TCP", "127.0.0.1", 6379)
+ client_protocol = client_factory.buildProtocol(client_address)
+
+ server_address = IPv4Address("TCP", host, port)
+ server_protocol = self._redis_server.buildProtocol(server_address)
client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol
@@ -463,7 +469,9 @@ class TestReplicationDataHandler(ReplicationDataHandler):
# list of received (stream_name, token, row) tuples
self.received_rdata_rows: List[Tuple[str, int, Any]] = []
- async def on_rdata(self, stream_name, instance_name, token, rows):
+ async def on_rdata(
+ self, stream_name: str, instance_name: str, token: int, rows: list
+ ) -> None:
await super().on_rdata(stream_name, instance_name, token, rows)
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))
@@ -472,28 +480,30 @@ class TestReplicationDataHandler(ReplicationDataHandler):
class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""
- def __init__(self):
+ def __init__(self) -> None:
self._subscribers_by_channel: Dict[
bytes, Set["FakeRedisPubSubProtocol"]
] = defaultdict(set)
- def add_subscriber(self, conn, channel: bytes):
+ def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
"""A connection has called SUBSCRIBE"""
self._subscribers_by_channel[channel].add(conn)
- def remove_subscriber(self, conn):
+ def remove_subscriber(self, conn: "FakeRedisPubSubProtocol") -> None:
"""A connection has lost connection"""
for subscribers in self._subscribers_by_channel.values():
subscribers.discard(conn)
- def publish(self, conn, channel: bytes, msg) -> int:
+ def publish(
+ self, conn: "FakeRedisPubSubProtocol", channel: bytes, msg: object
+ ) -> int:
"""A connection want to publish a message to subscribers."""
for sub in self._subscribers_by_channel[channel]:
sub.send(["message", channel, msg])
return len(self._subscribers_by_channel)
- def buildProtocol(self, addr):
+ def buildProtocol(self, addr: IPv4Address) -> "FakeRedisPubSubProtocol":
return FakeRedisPubSubProtocol(self)
@@ -506,7 +516,7 @@ class FakeRedisPubSubProtocol(Protocol):
self._server = server
self._reader = hiredis.Reader()
- def dataReceived(self, data):
+ def dataReceived(self, data: bytes) -> None:
self._reader.feed(data)
# We might get multiple messages in one packet.
@@ -523,7 +533,7 @@ class FakeRedisPubSubProtocol(Protocol):
self.handle_command(msg[0], *msg[1:])
- def handle_command(self, command, *args):
+ def handle_command(self, command: bytes, *args: bytes) -> None:
"""Received a Redis command from the client."""
# We currently only support pub/sub.
@@ -548,9 +558,9 @@ class FakeRedisPubSubProtocol(Protocol):
self.send("PONG")
else:
- raise Exception(f"Unknown command: {command}")
+ raise Exception(f"Unknown command: {command!r}")
- def send(self, msg):
+ def send(self, msg: object) -> None:
"""Send a message back to the client."""
assert self.transport is not None
@@ -559,7 +569,7 @@ class FakeRedisPubSubProtocol(Protocol):
self.transport.write(raw)
self.transport.flush()
- def encode(self, obj):
+ def encode(self, obj: object) -> str:
"""Encode an object to its Redis format.
Supports: strings/bytes, integers and list/tuples.
@@ -581,5 +591,5 @@ class FakeRedisPubSubProtocol(Protocol):
raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
- def connectionLost(self, reason):
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
self._server.remove_subscriber(self)
|