diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 6a7174b333..46a8e2013e 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -16,7 +16,9 @@ from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple
from twisted.internet.address import IPv4Address
-from twisted.internet.protocol import Protocol
+from twisted.internet.protocol import Protocol, connectionDone
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
from synapse.app.generic_worker import GenericWorkerServer
@@ -30,6 +32,7 @@ from synapse.replication.tcp.protocol import (
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport
@@ -51,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
@@ -92,8 +95,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
repl_handler,
)
- self._client_transport = None
- self._server_transport = None
+ self._client_transport: Optional[FakeTransport] = None
+ self._server_transport: Optional[FakeTransport] = None
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
@@ -107,10 +110,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
config["worker_replication_http_port"] = "8765"
return config
- def _build_replication_data_handler(self):
+ def _build_replication_data_handler(self) -> "TestReplicationDataHandler":
return TestReplicationDataHandler(self.worker_hs)
- def reconnect(self):
+ def reconnect(self) -> None:
if self._client_transport:
self.client.close()
@@ -123,7 +126,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._server_transport = FakeTransport(self.client, self.reactor)
self.server.makeConnection(self._server_transport)
- def disconnect(self):
+ def disconnect(self) -> None:
if self._client_transport:
self._client_transport = None
self.client.close()
@@ -132,7 +135,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._server_transport = None
self.server.close()
- def replicate(self):
+ def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
@@ -168,7 +171,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
requests: List[SynapseRequest] = []
real_request_factory = channel.requestFactory
- def request_factory(*args, **kwargs):
+ def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest:
request = real_request_factory(*args, **kwargs)
requests.append(request)
return request
@@ -202,7 +205,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
- ):
+ ) -> None:
"""Asserts that the given request is a HTTP replication request for
fetching updates for given stream.
"""
@@ -244,7 +247,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
base["redis"] = {"enabled": True}
return base
- def setUp(self):
+ def setUp(self) -> None:
super().setUp()
# build a replication server
@@ -287,7 +290,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
lambda: self._handle_http_replication_attempt(self.hs, 8765),
)
- def create_test_resource(self):
+ def create_test_resource(self) -> ReplicationRestResource:
"""Overrides `HomeserverTestCase.create_test_resource`."""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all
@@ -301,7 +304,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
return resource
def make_worker_hs(
- self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
+ self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
) -> HomeServer:
"""Make a new worker HS instance, correctly connecting replcation
stream to the master HS.
@@ -385,14 +388,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config["worker_replication_http_port"] = "8765"
return config
- def replicate(self):
+ def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
self.streamer.on_notifier_poke()
self.pump()
- def _handle_http_replication_attempt(self, hs, repl_port):
+ def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> None:
"""Handles a connection attempt to the given HS replication HTTP
listener on the given port.
"""
@@ -429,7 +432,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection.
- def connect_any_redis_attempts(self):
+ def connect_any_redis_attempts(self) -> None:
"""If redis is enabled we need to deal with workers connecting to a
redis server. We don't want to use a real Redis server so we use a
fake one.
@@ -440,8 +443,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)
- client_protocol = client_factory.buildProtocol(None)
- server_protocol = self._redis_server.buildProtocol(None)
+ client_address = IPv4Address("TCP", "127.0.0.1", 6379)
+ client_protocol = client_factory.buildProtocol(client_address)
+
+ server_address = IPv4Address("TCP", host, port)
+ server_protocol = self._redis_server.buildProtocol(server_address)
client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol
@@ -463,7 +469,9 @@ class TestReplicationDataHandler(ReplicationDataHandler):
# list of received (stream_name, token, row) tuples
self.received_rdata_rows: List[Tuple[str, int, Any]] = []
- async def on_rdata(self, stream_name, instance_name, token, rows):
+ async def on_rdata(
+ self, stream_name: str, instance_name: str, token: int, rows: list
+ ) -> None:
await super().on_rdata(stream_name, instance_name, token, rows)
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))
@@ -472,28 +480,30 @@ class TestReplicationDataHandler(ReplicationDataHandler):
class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""
- def __init__(self):
+ def __init__(self) -> None:
self._subscribers_by_channel: Dict[
bytes, Set["FakeRedisPubSubProtocol"]
] = defaultdict(set)
- def add_subscriber(self, conn, channel: bytes):
+ def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
"""A connection has called SUBSCRIBE"""
self._subscribers_by_channel[channel].add(conn)
- def remove_subscriber(self, conn):
+ def remove_subscriber(self, conn: "FakeRedisPubSubProtocol") -> None:
"""A connection has lost connection"""
for subscribers in self._subscribers_by_channel.values():
subscribers.discard(conn)
- def publish(self, conn, channel: bytes, msg) -> int:
+ def publish(
+ self, conn: "FakeRedisPubSubProtocol", channel: bytes, msg: object
+ ) -> int:
"""A connection want to publish a message to subscribers."""
for sub in self._subscribers_by_channel[channel]:
sub.send(["message", channel, msg])
return len(self._subscribers_by_channel)
- def buildProtocol(self, addr):
+ def buildProtocol(self, addr: IPv4Address) -> "FakeRedisPubSubProtocol":
return FakeRedisPubSubProtocol(self)
@@ -506,7 +516,7 @@ class FakeRedisPubSubProtocol(Protocol):
self._server = server
self._reader = hiredis.Reader()
- def dataReceived(self, data):
+ def dataReceived(self, data: bytes) -> None:
self._reader.feed(data)
# We might get multiple messages in one packet.
@@ -523,7 +533,7 @@ class FakeRedisPubSubProtocol(Protocol):
self.handle_command(msg[0], *msg[1:])
- def handle_command(self, command, *args):
+ def handle_command(self, command: bytes, *args: bytes) -> None:
"""Received a Redis command from the client."""
# We currently only support pub/sub.
@@ -548,9 +558,9 @@ class FakeRedisPubSubProtocol(Protocol):
self.send("PONG")
else:
- raise Exception(f"Unknown command: {command}")
+ raise Exception(f"Unknown command: {command!r}")
- def send(self, msg):
+ def send(self, msg: object) -> None:
"""Send a message back to the client."""
assert self.transport is not None
@@ -559,7 +569,7 @@ class FakeRedisPubSubProtocol(Protocol):
self.transport.write(raw)
self.transport.flush()
- def encode(self, obj):
+ def encode(self, obj: object) -> str:
"""Encode an object to its Redis format.
Supports: strings/bytes, integers and list/tuples.
@@ -581,5 +591,5 @@ class FakeRedisPubSubProtocol(Protocol):
raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
- def connectionLost(self, reason):
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
self._server.remove_subscriber(self)
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
index e03d9b4cc0..9be11ab802 100644
--- a/tests/replication/http/test__base.py
+++ b/tests/replication/http/test__base.py
@@ -74,7 +74,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
"""Tests for `ReplicationEndpoint` cancellation."""
- def create_test_resource(self):
+ def create_test_resource(self) -> JsonResource:
"""Overrides `HomeserverTestCase.create_test_resource`."""
resource = JsonResource(self.hs)
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index c5705256e6..4c9b494344 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -13,35 +13,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Iterable, Optional
from unittest.mock import Mock
-from tests.replication._base import BaseStreamTestCase
+from twisted.test.proto_helpers import MemoryReactor
+from synapse.server import HomeServer
+from synapse.util import Clock
-class BaseSlavedStoreTestCase(BaseStreamTestCase):
- def make_homeserver(self, reactor, clock):
+from tests.replication._base import BaseStreamTestCase
- hs = self.setup_test_homeserver(federation_client=Mock())
- return hs
+class BaseSlavedStoreTestCase(BaseStreamTestCase):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ return self.setup_test_homeserver(federation_client=Mock())
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
self.reconnect()
self.master_store = hs.get_datastores().main
self.slaved_store = self.worker_hs.get_datastores().main
- self._storage_controllers = hs.get_storage_controllers()
+ persistence = hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.persistance = persistence
- def replicate(self):
+ def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
self.streamer.on_notifier_poke()
self.pump(0.1)
- def check(self, method, args, expected_result=None):
+ def check(
+ self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None
+ ) -> None:
master_result = self.get_success(getattr(self.master_store, method)(*args))
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
if expected_result is not None:
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index dce71f7334..ddca9d696c 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Iterable, Optional
+from typing import Any, Callable, Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
from parameterized import parameterized
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import ReceiptTypes
from synapse.api.room_versions import RoomVersions
-from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
+from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
+from synapse.events.snapshot import EventContext
from synapse.handlers.room import RoomEventSource
+from synapse.server import HomeServer
from synapse.storage.databases.main.event_push_actions import (
NotifCounts,
RoomNotifCounts,
@@ -28,6 +32,7 @@ from synapse.storage.databases.main.event_push_actions import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
from synapse.types import PersistedEventPosition
+from synapse.util import Clock
from tests.server import FakeTransport
@@ -41,19 +46,19 @@ ROOM_ID = "!room:test"
logger = logging.getLogger(__name__)
-def dict_equals(self, other):
+def dict_equals(self: EventBase, other: EventBase) -> bool:
me = encode_canonical_json(self.get_pdu_json())
them = encode_canonical_json(other.get_pdu_json())
return me == them
-def patch__eq__(cls):
+def patch__eq__(cls: object) -> Callable[[], None]:
eq = getattr(cls, "__eq__", None)
- cls.__eq__ = dict_equals
+ cls.__eq__ = dict_equals # type: ignore[assignment]
- def unpatch():
+ def unpatch() -> None:
if eq is not None:
- cls.__eq__ = eq
+ cls.__eq__ = eq # type: ignore[assignment]
return unpatch
@@ -62,14 +67,14 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = EventsWorkerStore
- def setUp(self):
+ def setUp(self) -> None:
# Patch up the equality operator for events so that we can check
# whether lists of events match using assertEqual
- self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
- return super().setUp()
+ self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)]
+ super().setUp()
- def prepare(self, *args, **kwargs):
- super().prepare(*args, **kwargs)
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ super().prepare(reactor, clock, hs)
self.get_success(
self.master_store.store_room(
@@ -80,10 +85,10 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
)
)
- def tearDown(self):
+ def tearDown(self) -> None:
[unpatch() for unpatch in self.unpatches]
- def test_get_latest_event_ids_in_room(self):
+ def test_get_latest_event_ids_in_room(self) -> None:
create = self.persist(type="m.room.create", key="", creator=USER_ID)
self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
@@ -97,7 +102,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
- def test_redactions(self):
+ def test_redactions(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
@@ -117,7 +122,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
)
self.check("get_event", [msg.event_id], redacted)
- def test_backfilled_redactions(self):
+ def test_backfilled_redactions(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
@@ -139,7 +144,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
)
self.check("get_event", [msg.event_id], redacted)
- def test_invites(self):
+ def test_invites(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
@@ -163,7 +168,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
)
@parameterized.expand([(True,), (False,)])
- def test_push_actions_for_user(self, send_receipt: bool):
+ def test_push_actions_for_user(self, send_receipt: bool) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
self.persist(
@@ -219,7 +224,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
),
)
- def test_get_rooms_for_user_with_stream_ordering(self):
+ def test_get_rooms_for_user_with_stream_ordering(self) -> None:
"""Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
by rows in the events stream
"""
@@ -243,7 +248,9 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
{GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
)
- def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
+ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(
+ self,
+ ) -> None:
"""Check that current_state invalidation happens correctly with multiple events
in the persistence batch.
@@ -283,11 +290,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
)
msg, msgctx = self.build_event()
- self.get_success(
- self._storage_controllers.persistence.persist_events(
- [(j2, j2ctx), (msg, msgctx)]
- )
- )
+ self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)]))
self.replicate()
assert j2.internal_metadata.stream_ordering is not None
@@ -339,7 +342,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
event_id = 0
- def persist(self, backfill=False, **kwargs) -> FrozenEvent:
+ def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase:
"""
Returns:
The event that was persisted.
@@ -348,32 +351,28 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
if backfill:
self.get_success(
- self._storage_controllers.persistence.persist_events(
- [(event, context)], backfilled=True
- )
+ self.persistance.persist_events([(event, context)], backfilled=True)
)
else:
- self.get_success(
- self._storage_controllers.persistence.persist_event(event, context)
- )
+ self.get_success(self.persistance.persist_event(event, context))
return event
def build_event(
self,
- sender=USER_ID,
- room_id=ROOM_ID,
- type="m.room.message",
- key=None,
+ sender: str = USER_ID,
+ room_id: str = ROOM_ID,
+ type: str = "m.room.message",
+ key: Optional[str] = None,
internal: Optional[dict] = None,
- depth=None,
- prev_events: Optional[list] = None,
- auth_events: Optional[list] = None,
- prev_state: Optional[list] = None,
- redacts=None,
+ depth: Optional[int] = None,
+ prev_events: Optional[List[Tuple[str, dict]]] = None,
+ auth_events: Optional[List[str]] = None,
+ prev_state: Optional[List[str]] = None,
+ redacts: Optional[str] = None,
push_actions: Iterable = frozenset(),
- **content,
- ):
+ **content: object,
+ ) -> Tuple[EventBase, EventContext]:
prev_events = prev_events or []
auth_events = auth_events or []
prev_state = prev_state or []
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
index 50fbff5f32..01df1be047 100644
--- a/tests/replication/tcp/streams/test_account_data.py
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -21,7 +21,7 @@ from tests.replication._base import BaseStreamTestCase
class AccountDataStreamTestCase(BaseStreamTestCase):
- def test_update_function_room_account_data_limit(self):
+ def test_update_function_room_account_data_limit(self) -> None:
"""Test replication with many room account data updates"""
store = self.hs.get_datastores().main
@@ -67,7 +67,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
- def test_update_function_global_account_data_limit(self):
+ def test_update_function_global_account_data_limit(self) -> None:
"""Test replication with many global account data updates"""
store = self.hs.get_datastores().main
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 641a94133b..043dbe76af 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional
+from typing import Any, List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
@@ -25,6 +27,8 @@ from synapse.replication.tcp.streams.events import (
)
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.replication._base import BaseStreamTestCase
from tests.test_utils.event_injection import inject_event, inject_member_event
@@ -37,7 +41,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
room.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass")
@@ -47,7 +51,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.room_id = self.helper.create_room_as(tok=self.user_tok)
self.test_handler.received_rdata_rows.clear()
- def test_update_function_event_row_limit(self):
+ def test_update_function_event_row_limit(self) -> None:
"""Test replication with many non-state events
Checks that all events are correctly replicated when there are lots of
@@ -102,7 +106,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
- def test_update_function_huge_state_change(self):
+ def test_update_function_huge_state_change(self) -> None:
"""Test replication with many state events
Ensures that all events are correctly replicated when there are lots of
@@ -256,7 +260,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
# "None" indicates the state has been deleted
self.assertIsNone(sr.event_id)
- def test_update_function_state_row_limit(self):
+ def test_update_function_state_row_limit(self) -> None:
"""Test replication with many state events over several stream ids."""
# we want to generate lots of state changes, but for this test, we want to
@@ -376,7 +380,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
- def test_backwards_stream_id(self):
+ def test_backwards_stream_id(self) -> None:
"""
Test that RDATA that comes after the current position should be discarded.
"""
@@ -437,7 +441,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
event_count = 0
def _inject_test_event(
- self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
+ self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs: Any
) -> EventBase:
if sender is None:
sender = self.user_id
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
index bcb82c9c80..cdbdfaf057 100644
--- a/tests/replication/tcp/streams/test_federation.py
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -26,7 +26,7 @@ class FederationStreamTestCase(BaseStreamTestCase):
config["federation_sender_instances"] = ["federation_sender1"]
return config
- def test_catchup(self):
+ def test_catchup(self) -> None:
"""Basic test of catchup on reconnect
Makes sure that updates sent while we are offline are received later.
diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py
index 2c10eab4db..38b5020ce0 100644
--- a/tests/replication/tcp/streams/test_partial_state.py
+++ b/tests/replication/tcp/streams/test_partial_state.py
@@ -23,7 +23,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
hijack_auth = True
user_id = "@bob:test"
- def setUp(self):
+ def setUp(self) -> None:
super().setUp()
self.store = self.hs.get_datastores().main
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 9a229dd23f..68de5d1cc2 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -27,10 +27,11 @@ ROOM_ID_2 = "!foo:blue"
class TypingStreamTestCase(BaseStreamTestCase):
- def _build_replication_data_handler(self):
- return Mock(wraps=super()._build_replication_data_handler())
+ def _build_replication_data_handler(self) -> Mock:
+ self.mock_handler = Mock(wraps=super()._build_replication_data_handler())
+ return self.mock_handler
- def test_typing(self):
+ def test_typing(self) -> None:
typing = self.hs.get_typing_handler()
self.reconnect()
@@ -43,8 +44,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")
- self.test_handler.on_rdata.assert_called_once()
- stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.mock_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row: TypingStream.TypingStreamRow = rdata_rows[0]
@@ -54,11 +55,11 @@ class TypingStreamTestCase(BaseStreamTestCase):
# Now let's disconnect and insert some data.
self.disconnect()
- self.test_handler.on_rdata.reset_mock()
+ self.mock_handler.on_rdata.reset_mock()
typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
- self.test_handler.on_rdata.assert_not_called()
+ self.mock_handler.on_rdata.assert_not_called()
self.reconnect()
self.pump(0.1)
@@ -71,15 +72,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
assert request.args is not None
self.assertEqual(int(request.args[b"from_token"][0]), token)
- self.test_handler.on_rdata.assert_called_once()
- stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.mock_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([], row.user_ids)
- def test_reset(self):
+ def test_reset(self) -> None:
"""
Test what happens when a typing stream resets.
@@ -98,8 +99,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")
- self.test_handler.on_rdata.assert_called_once()
- stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.mock_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row: TypingStream.TypingStreamRow = rdata_rows[0]
@@ -134,15 +135,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assert_request_is_get_repl_stream_updates(request, "typing")
# Reset the test code.
- self.test_handler.on_rdata.reset_mock()
- self.test_handler.on_rdata.assert_not_called()
+ self.mock_handler.on_rdata.reset_mock()
+ self.mock_handler.on_rdata.assert_not_called()
# Push additional data.
typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
self.reactor.advance(0)
- self.test_handler.on_rdata.assert_called_once()
- stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.mock_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py
index cca7ebb719..5d6b72b16d 100644
--- a/tests/replication/tcp/test_commands.py
+++ b/tests/replication/tcp/test_commands.py
@@ -21,12 +21,12 @@ from tests.unittest import TestCase
class ParseCommandTestCase(TestCase):
- def test_parse_one_word_command(self):
+ def test_parse_one_word_command(self) -> None:
line = "REPLICATE"
cmd = parse_command_from_line(line)
self.assertIsInstance(cmd, ReplicateCommand)
- def test_parse_rdata(self):
+ def test_parse_rdata(self) -> None:
line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
cmd = parse_command_from_line(line)
assert isinstance(cmd, RdataCommand)
@@ -34,7 +34,7 @@ class ParseCommandTestCase(TestCase):
self.assertEqual(cmd.instance_name, "master")
self.assertEqual(cmd.token, 6287863)
- def test_parse_rdata_batch(self):
+ def test_parse_rdata_batch(self) -> None:
line = 'RDATA presence master batch ["@foo:example.com", "online"]'
cmd = parse_command_from_line(line)
assert isinstance(cmd, RdataCommand)
diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py
index 545f11acd1..b75fc05fd5 100644
--- a/tests/replication/tcp/test_remote_server_up.py
+++ b/tests/replication/tcp/test_remote_server_up.py
@@ -16,15 +16,17 @@ from typing import Tuple
from twisted.internet.address import IPv4Address
from twisted.internet.interfaces import IProtocol
-from twisted.test.proto_helpers import StringTransport
+from twisted.test.proto_helpers import MemoryReactor, StringTransport
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class RemoteServerUpTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.factory = ReplicationStreamProtocolFactory(hs)
def _make_client(self) -> Tuple[IProtocol, StringTransport]:
@@ -40,7 +42,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
return proto, transport
- def test_relay(self):
+ def test_relay(self) -> None:
"""Test that Synapse will relay REMOTE_SERVER_UP commands to all
other connections, but not the one that sent it.
"""
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
index 5d7a89e0c7..98602371e4 100644
--- a/tests/replication/test_auth.py
+++ b/tests/replication/test_auth.py
@@ -13,7 +13,11 @@
# limitations under the License.
import logging
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.rest.client import register
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, make_request
@@ -27,7 +31,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
servlets = [register.register_servlets]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
# This isn't a real configuration option but is used to provide the main
# homeserver and worker homeserver different options.
@@ -77,7 +81,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
{"auth": {"session": session, "type": "m.login.dummy"}},
)
- def test_no_auth(self):
+ def test_no_auth(self) -> None:
"""With no authentication the request should finish."""
channel = self._test_register()
self.assertEqual(channel.code, 200)
@@ -86,7 +90,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(channel.json_body["user_id"], "@user:test")
@override_config({"main_replication_secret": "my-secret"})
- def test_missing_auth(self):
+ def test_missing_auth(self) -> None:
"""If the main process expects a secret that is not provided, an error results."""
channel = self._test_register()
self.assertEqual(channel.code, 500)
@@ -97,13 +101,13 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
"worker_replication_secret": "wrong-secret",
}
)
- def test_unauthorized(self):
+ def test_unauthorized(self) -> None:
"""If the main process receives the wrong secret, an error results."""
channel = self._test_register()
self.assertEqual(channel.code, 500)
@override_config({"worker_replication_secret": "my-secret"})
- def test_authorized(self):
+ def test_authorized(self) -> None:
"""The request should finish when the worker provides the authentication header."""
channel = self._test_register()
self.assertEqual(channel.code, 200)
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index eb5b376534..eca5033761 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -33,7 +33,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
config["worker_replication_http_port"] = "8765"
return config
- def test_register_single_worker(self):
+ def test_register_single_worker(self) -> None:
"""Test that registration works when using a single generic worker."""
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
site = self._hs_to_site[worker_hs]
@@ -63,7 +63,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
# We're given a registered user.
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
- def test_register_multi_worker(self):
+ def test_register_multi_worker(self) -> None:
"""Test that registration works when using multiple generic workers."""
worker_hs_1 = self.make_worker_hs("synapse.app.generic_worker")
worker_hs_2 = self.make_worker_hs("synapse.app.generic_worker")
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 63b1dd40b5..12668b34c5 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -14,10 +14,14 @@
from unittest import mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.app.generic_worker import GenericWorkerServer
from synapse.replication.tcp.commands import FederationAckCommand
from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams.federation import FederationStream
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -30,12 +34,10 @@ class FederationAckTestCase(HomeserverTestCase):
config["federation_sender_instances"] = ["federation_sender1"]
return config
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
-
- return hs
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ return self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
- def test_federation_ack_sent(self):
+ def test_federation_ack_sent(self) -> None:
"""A FEDERATION_ACK should be sent back after each RDATA federation
This test checks that the federation sender is correctly sending back
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index c28073b8f7..89380e25b5 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -40,7 +40,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
room.register_servlets,
]
- def test_send_event_single_sender(self):
+ def test_send_event_single_sender(self) -> None:
"""Test that using a single federation sender worker correctly sends a
new event.
"""
@@ -71,7 +71,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(mock_client.put_json.call_args[0][0], "other_server")
self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus"))
- def test_send_event_sharded(self):
+ def test_send_event_sharded(self) -> None:
"""Test that using two federation sender workers correctly sends
new events.
"""
@@ -138,7 +138,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.assertTrue(sent_on_1)
self.assertTrue(sent_on_2)
- def test_send_typing_sharded(self):
+ def test_send_typing_sharded(self) -> None:
"""Test that using two federation sender workers correctly sends
new typing EDUs.
"""
@@ -215,7 +215,9 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.assertTrue(sent_on_1)
self.assertTrue(sent_on_2)
- def create_room_with_remote_server(self, user, token, remote_server="other_server"):
+ def create_room_with_remote_server(
+ self, user: str, token: str, remote_server: str = "other_server"
+ ) -> str:
room = self.helper.create_room_as(user, tok=token)
store = self.hs.get_datastores().main
federation = self.hs.get_federation_event_handler()
diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py
index b93cae67d3..9c4fbda71b 100644
--- a/tests/replication/test_module_cache_invalidation.py
+++ b/tests/replication/test_module_cache_invalidation.py
@@ -39,7 +39,7 @@ class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
synapse.rest.admin.register_servlets,
]
- def test_module_cache_full_invalidation(self):
+ def test_module_cache_full_invalidation(self) -> None:
main_cache = TestCache()
self.hs.get_module_api().register_cached_function(main_cache.cached_function)
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 96cdf2c45b..1527b4a82d 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -18,12 +18,14 @@ from typing import Optional, Tuple
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
+from twisted.test.proto_helpers import MemoryReactor
from twisted.web.http import HTTPChannel
from twisted.web.server import Request
from synapse.rest import admin
from synapse.rest.client import login
from synapse.server import HomeServer
+from synapse.util import Clock
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -43,13 +45,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
self.reactor.lookups["example.com"] = "1.2.3.4"
- def default_config(self):
+ def default_config(self) -> dict:
conf = super().default_config()
conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
return conf
@@ -122,7 +124,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return channel, request
- def test_basic(self):
+ def test_basic(self) -> None:
"""Test basic fetching of remote media from a single worker."""
hs1 = self.make_worker_hs("synapse.app.generic_worker")
@@ -138,7 +140,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"Hello!")
- def test_download_simple_file_race(self):
+ def test_download_simple_file_race(self) -> None:
"""Test that fetching remote media from two different processes at the
same time works.
"""
@@ -177,7 +179,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
# We expect only one new file to have been persisted.
self.assertEqual(start_count + 1, self._count_remote_media())
- def test_download_image_race(self):
+ def test_download_image_race(self) -> None:
"""Test that fetching remote *images* from two different processes at
the same time works.
@@ -229,7 +231,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return sum(len(files) for _, _, files in os.walk(path))
-def get_connection_factory():
+def get_connection_factory() -> TestServerTLSConnectionFactory:
# this needs to happen once, but not until we are ready to run the first test
global test_server_connection_factory
if test_server_connection_factory is None:
@@ -263,6 +265,6 @@ def _build_test_server(
return server_tls_factory.buildProtocol(None)
-def _log_request(request):
+def _log_request(request: Request) -> None:
"""Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request)
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index ca18ad6553..9345cfbeb2 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -15,9 +15,12 @@ import logging
from unittest.mock import Mock
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -33,12 +36,12 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Register a user who sends a message that we'll get notified about
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
- def _create_pusher_and_send_msg(self, localpart):
+ def _create_pusher_and_send_msg(self, localpart: str) -> str:
# Create a user that will get push notifications
user_id = self.register_user(localpart, "pass")
access_token = self.login(localpart, "pass")
@@ -79,7 +82,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
return event_id
- def test_send_push_single_worker(self):
+ def test_send_push_single_worker(self) -> None:
"""Test that registration works when using a pusher worker."""
http_client_mock = Mock(spec_set=["post_json_get_json"])
http_client_mock.post_json_get_json.side_effect = (
@@ -109,7 +112,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
],
)
- def test_send_push_multiple_workers(self):
+ def test_send_push_multiple_workers(self) -> None:
"""Test that registration works when using sharded pusher workers."""
http_client_mock1 = Mock(spec_set=["post_json_get_json"])
http_client_mock1.post_json_get_json.side_effect = (
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 541d390286..7f9cc67e73 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -14,9 +14,13 @@
import logging
from unittest.mock import patch
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.rest import admin
from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
@@ -34,7 +38,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
sync.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Register a user who sends a message that we'll get notified about
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
@@ -42,7 +46,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.room_creator = self.hs.get_room_creation_handler()
self.store = hs.get_datastores().main
- def default_config(self):
+ def default_config(self) -> dict:
conf = super().default_config()
conf["stream_writers"] = {"events": ["worker1", "worker2"]}
conf["instance_map"] = {
@@ -51,7 +55,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
}
return conf
- def _create_room(self, room_id: str, user_id: str, tok: str):
+ def _create_room(self, room_id: str, user_id: str, tok: str) -> None:
"""Create a room with given room_id"""
# We control the room ID generation by patching out the
@@ -62,7 +66,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
mock.side_effect = lambda: room_id
self.helper.create_room_as(user_id, tok=tok)
- def test_basic(self):
+ def test_basic(self) -> None:
"""Simple test to ensure that multiple rooms can be created and joined,
and that different rooms get handled by different instances.
"""
@@ -112,7 +116,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertTrue(persisted_on_1)
self.assertTrue(persisted_on_2)
- def test_vector_clock_token(self):
+ def test_vector_clock_token(self) -> None:
"""Tests that using a stream token with a vector clock component works
correctly with basic /sync and /messages usage.
"""
|