diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 386ea70a25..e9fd991718 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -53,9 +53,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
- self.server = server_factory.buildProtocol(
+ self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
None
- ) # type: ServerReplicationStreamProtocol
+ )
# Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -195,7 +195,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
fetching updates for given stream.
"""
- path = request.path # type: bytes # type: ignore
+ path: bytes = request.path # type: ignore
self.assertRegex(
path,
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
@@ -212,7 +212,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
unlike `BaseStreamTestCase`.
"""
- servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]]
+ servlets: List[Callable[[HomeServer, JsonResource], None]] = []
def setUp(self):
super().setUp()
@@ -448,7 +448,7 @@ class TestReplicationDataHandler(ReplicationDataHandler):
super().__init__(hs)
# list of received (stream_name, token, row) tuples
- self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
+ self.received_rdata_rows: List[Tuple[str, int, Any]] = []
async def on_rdata(self, stream_name, instance_name, token, rows):
await super().on_rdata(stream_name, instance_name, token, rows)
@@ -484,7 +484,7 @@ class FakeRedisPubSubServer:
class FakeRedisPubSubProtocol(Protocol):
"""A connection from a client talking to the fake Redis server."""
- transport = None # type: Optional[FakeTransport]
+ transport: Optional[FakeTransport] = None
def __init__(self, server: FakeRedisPubSubServer):
self._server = server
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index f51fa0a79e..666008425a 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -135,9 +135,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
# this is the point in the DAG where we make a fork
- fork_point = self.get_success(
+ fork_point: List[str] = self.get_success(
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
- ) # type: List[str]
+ )
events = [
self._inject_state_event(sender=OTHER_USER)
@@ -238,7 +238,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual(row.data.event_id, pl_event.event_id)
# the state rows are unsorted
- state_rows = [] # type: List[EventsStreamCurrentStateRow]
+ state_rows: List[EventsStreamCurrentStateRow] = []
for stream_name, _, row in received_rows:
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
@@ -290,11 +290,11 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
# this is the point in the DAG where we make a fork
- fork_point = self.get_success(
+ fork_point: List[str] = self.get_success(
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
- ) # type: List[str]
+ )
- events = [] # type: List[EventBase]
+ events: List[EventBase] = []
for user in user_ids:
events.extend(
self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
@@ -355,7 +355,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual(row.data.event_id, pl_events[i].event_id)
# the state rows are unsorted
- state_rows = [] # type: List[EventsStreamCurrentStateRow]
+ state_rows: List[EventsStreamCurrentStateRow] = []
for _ in range(STATES_PER_USER + 1):
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index 7f5d932f0b..38e292c1ab 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -43,7 +43,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows))
- row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
+ row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
self.assertEqual("!room:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
@@ -75,7 +75,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
self.assertEqual(token, 3)
self.assertEqual(1, len(rdata_rows))
- row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
+ row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
self.assertEqual("!room2:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index ecd360c2d0..3ff5afc6e5 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -47,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
- row = rdata_rows[0] # type: TypingStream.TypingStreamRow
+ row: TypingStream.TypingStreamRow = rdata_rows[0]
self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([USER_ID], row.user_ids)
@@ -102,7 +102,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
- row = rdata_rows[0] # type: TypingStream.TypingStreamRow
+ row: TypingStream.TypingStreamRow = rdata_rows[0]
self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([USER_ID], row.user_ids)
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index b42f1288eb..ffa425328f 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -31,7 +31,7 @@ from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
logger = logging.getLogger(__name__)
-test_server_connection_factory = None # type: Optional[TestServerTLSConnectionFactory]
+test_server_connection_factory: Optional[TestServerTLSConnectionFactory] = None
class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|