diff options
Diffstat (limited to 'tests/replication/_base.py')
-rw-r--r-- | tests/replication/_base.py | 70 |
1 files changed, 40 insertions, 30 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 6a7174b333..46a8e2013e 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -16,7 +16,9 @@ from collections import defaultdict from typing import Any, Dict, List, Optional, Set, Tuple from twisted.internet.address import IPv4Address -from twisted.internet.protocol import Protocol +from twisted.internet.protocol import Protocol, connectionDone +from twisted.python.failure import Failure +from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource from synapse.app.generic_worker import GenericWorkerServer @@ -30,6 +32,7 @@ from synapse.replication.tcp.protocol import ( ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeTransport @@ -51,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): if not hiredis: skip = "Requires hiredis" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # build a replication server server_factory = ReplicationStreamProtocolFactory(hs) self.streamer = hs.get_replication_streamer() @@ -92,8 +95,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): repl_handler, ) - self._client_transport = None - self._server_transport = None + self._client_transport: Optional[FakeTransport] = None + self._server_transport: Optional[FakeTransport] = None def create_resource_dict(self) -> Dict[str, Resource]: d = super().create_resource_dict() @@ -107,10 +110,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): config["worker_replication_http_port"] = "8765" return config - def _build_replication_data_handler(self): + def _build_replication_data_handler(self) -> "TestReplicationDataHandler": return TestReplicationDataHandler(self.worker_hs) - def reconnect(self): + def reconnect(self) -> None: if self._client_transport: self.client.close() @@ -123,7 +126,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self._server_transport = FakeTransport(self.client, self.reactor) self.server.makeConnection(self._server_transport) - def disconnect(self): + def disconnect(self) -> None: if self._client_transport: self._client_transport = None self.client.close() @@ -132,7 +135,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self._server_transport = None self.server.close() - def replicate(self): + def replicate(self) -> None: """Tell the master side of replication that something has happened, and then wait for the replication to occur. """ @@ -168,7 +171,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): requests: List[SynapseRequest] = [] real_request_factory = channel.requestFactory - def request_factory(*args, **kwargs): + def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest: request = real_request_factory(*args, **kwargs) requests.append(request) return request @@ -202,7 +205,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): def assert_request_is_get_repl_stream_updates( self, request: SynapseRequest, stream_name: str - ): + ) -> None: """Asserts that the given request is a HTTP replication request for fetching updates for given stream. """ @@ -244,7 +247,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): base["redis"] = {"enabled": True} return base - def setUp(self): + def setUp(self) -> None: super().setUp() # build a replication server @@ -287,7 +290,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): lambda: self._handle_http_replication_attempt(self.hs, 8765), ) - def create_test_resource(self): + def create_test_resource(self) -> ReplicationRestResource: """Overrides `HomeserverTestCase.create_test_resource`.""" # We override this so that it automatically registers all the HTTP # replication servlets, without having to explicitly do that in all @@ -301,7 +304,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): return resource def make_worker_hs( - self, worker_app: str, extra_config: Optional[dict] = None, **kwargs + self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any ) -> HomeServer: """Make a new worker HS instance, correctly connecting replcation stream to the master HS. @@ -385,14 +388,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config["worker_replication_http_port"] = "8765" return config - def replicate(self): + def replicate(self) -> None: """Tell the master side of replication that something has happened, and then wait for the replication to occur. """ self.streamer.on_notifier_poke() self.pump() - def _handle_http_replication_attempt(self, hs, repl_port): + def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> None: """Handles a connection attempt to the given HS replication HTTP listener on the given port. """ @@ -429,7 +432,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # inside `connecTCP` before the connection has been passed back to the # code that requested the TCP connection. - def connect_any_redis_attempts(self): + def connect_any_redis_attempts(self) -> None: """If redis is enabled we need to deal with workers connecting to a redis server. We don't want to use a real Redis server so we use a fake one. @@ -440,8 +443,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.assertEqual(host, "localhost") self.assertEqual(port, 6379) - client_protocol = client_factory.buildProtocol(None) - server_protocol = self._redis_server.buildProtocol(None) + client_address = IPv4Address("TCP", "127.0.0.1", 6379) + client_protocol = client_factory.buildProtocol(client_address) + + server_address = IPv4Address("TCP", host, port) + server_protocol = self._redis_server.buildProtocol(server_address) client_to_server_transport = FakeTransport( server_protocol, self.reactor, client_protocol @@ -463,7 +469,9 @@ class TestReplicationDataHandler(ReplicationDataHandler): # list of received (stream_name, token, row) tuples self.received_rdata_rows: List[Tuple[str, int, Any]] = [] - async def on_rdata(self, stream_name, instance_name, token, rows): + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ) -> None: await super().on_rdata(stream_name, instance_name, token, rows) for r in rows: self.received_rdata_rows.append((stream_name, token, r)) @@ -472,28 +480,30 @@ class TestReplicationDataHandler(ReplicationDataHandler): class FakeRedisPubSubServer: """A fake Redis server for pub/sub.""" - def __init__(self): + def __init__(self) -> None: self._subscribers_by_channel: Dict[ bytes, Set["FakeRedisPubSubProtocol"] ] = defaultdict(set) - def add_subscriber(self, conn, channel: bytes): + def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None: """A connection has called SUBSCRIBE""" self._subscribers_by_channel[channel].add(conn) - def remove_subscriber(self, conn): + def remove_subscriber(self, conn: "FakeRedisPubSubProtocol") -> None: """A connection has lost connection""" for subscribers in self._subscribers_by_channel.values(): subscribers.discard(conn) - def publish(self, conn, channel: bytes, msg) -> int: + def publish( + self, conn: "FakeRedisPubSubProtocol", channel: bytes, msg: object + ) -> int: """A connection want to publish a message to subscribers.""" for sub in self._subscribers_by_channel[channel]: sub.send(["message", channel, msg]) return len(self._subscribers_by_channel) - def buildProtocol(self, addr): + def buildProtocol(self, addr: IPv4Address) -> "FakeRedisPubSubProtocol": return FakeRedisPubSubProtocol(self) @@ -506,7 +516,7 @@ class FakeRedisPubSubProtocol(Protocol): self._server = server self._reader = hiredis.Reader() - def dataReceived(self, data): + def dataReceived(self, data: bytes) -> None: self._reader.feed(data) # We might get multiple messages in one packet. @@ -523,7 +533,7 @@ class FakeRedisPubSubProtocol(Protocol): self.handle_command(msg[0], *msg[1:]) - def handle_command(self, command, *args): + def handle_command(self, command: bytes, *args: bytes) -> None: """Received a Redis command from the client.""" # We currently only support pub/sub. @@ -548,9 +558,9 @@ class FakeRedisPubSubProtocol(Protocol): self.send("PONG") else: - raise Exception(f"Unknown command: {command}") + raise Exception(f"Unknown command: {command!r}") - def send(self, msg): + def send(self, msg: object) -> None: """Send a message back to the client.""" assert self.transport is not None @@ -559,7 +569,7 @@ class FakeRedisPubSubProtocol(Protocol): self.transport.write(raw) self.transport.flush() - def encode(self, obj): + def encode(self, obj: object) -> str: """Encode an object to its Redis format. Supports: strings/bytes, integers and list/tuples. @@ -581,5 +591,5 @@ class FakeRedisPubSubProtocol(Protocol): raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj) - def connectionLost(self, reason): + def connectionLost(self, reason: Failure = connectionDone) -> None: self._server.remove_subscriber(self) |