summary refs log tree commit diff
path: root/tests/replication
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication')
-rw-r--r--tests/replication/_base.py18
-rw-r--r--tests/replication/tcp/streams/test_events.py14
-rw-r--r--tests/replication/tcp/streams/test_receipts.py4
-rw-r--r--tests/replication/tcp/streams/test_typing.py4
-rw-r--r--tests/replication/test_multi_media_repo.py6
-rw-r--r--tests/replication/test_sharded_event_persister.py6
6 files changed, 26 insertions, 26 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 624bd1b927..e9fd991718 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -53,9 +53,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
         self.streamer = hs.get_replication_streamer()
-        self.server = server_factory.buildProtocol(
+        self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
             None
-        )  # type: ServerReplicationStreamProtocol
+        )
 
         # Make a new HomeServer object for the worker
         self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -195,7 +195,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         fetching updates for given stream.
         """
 
-        path = request.path  # type: bytes  # type: ignore
+        path: bytes = request.path  # type: ignore
         self.assertRegex(
             path,
             br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
@@ -212,7 +212,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
     unlike `BaseStreamTestCase`.
     """
 
-    servlets = []  # type: List[Callable[[HomeServer, JsonResource], None]]
+    servlets: List[Callable[[HomeServer, JsonResource], None]] = []
 
     def setUp(self):
         super().setUp()
@@ -448,7 +448,7 @@ class TestReplicationDataHandler(ReplicationDataHandler):
         super().__init__(hs)
 
         # list of received (stream_name, token, row) tuples
-        self.received_rdata_rows = []  # type: List[Tuple[str, int, Any]]
+        self.received_rdata_rows: List[Tuple[str, int, Any]] = []
 
     async def on_rdata(self, stream_name, instance_name, token, rows):
         await super().on_rdata(stream_name, instance_name, token, rows)
@@ -484,7 +484,7 @@ class FakeRedisPubSubServer:
 class FakeRedisPubSubProtocol(Protocol):
     """A connection from a client talking to the fake Redis server."""
 
-    transport = None  # type: Optional[FakeTransport]
+    transport: Optional[FakeTransport] = None
 
     def __init__(self, server: FakeRedisPubSubServer):
         self._server = server
@@ -550,12 +550,12 @@ class FakeRedisPubSubProtocol(Protocol):
         if obj is None:
             return "$-1\r\n"
         if isinstance(obj, str):
-            return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
+            return f"${len(obj)}\r\n{obj}\r\n"
         if isinstance(obj, int):
-            return ":{val}\r\n".format(val=obj)
+            return f":{obj}\r\n"
         if isinstance(obj, (list, tuple)):
             items = "".join(self.encode(a) for a in obj)
-            return "*{len}\r\n{items}".format(len=len(obj), items=items)
+            return f"*{len(obj)}\r\n{items}"
 
         raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
 
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index f51fa0a79e..666008425a 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -135,9 +135,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point = self.get_success(
+        fork_point: List[str] = self.get_success(
             self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
-        )  # type: List[str]
+        )
 
         events = [
             self._inject_state_event(sender=OTHER_USER)
@@ -238,7 +238,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.assertEqual(row.data.event_id, pl_event.event_id)
 
         # the state rows are unsorted
-        state_rows = []  # type: List[EventsStreamCurrentStateRow]
+        state_rows: List[EventsStreamCurrentStateRow] = []
         for stream_name, _, row in received_rows:
             self.assertEqual("events", stream_name)
             self.assertIsInstance(row, EventsStreamRow)
@@ -290,11 +290,11 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point = self.get_success(
+        fork_point: List[str] = self.get_success(
             self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
-        )  # type: List[str]
+        )
 
-        events = []  # type: List[EventBase]
+        events: List[EventBase] = []
         for user in user_ids:
             events.extend(
                 self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
@@ -355,7 +355,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
             self.assertEqual(row.data.event_id, pl_events[i].event_id)
 
             # the state rows are unsorted
-            state_rows = []  # type: List[EventsStreamCurrentStateRow]
+            state_rows: List[EventsStreamCurrentStateRow] = []
             for _ in range(STATES_PER_USER + 1):
                 stream_name, token, row = received_rows.pop(0)
                 self.assertEqual("events", stream_name)
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index 7f5d932f0b..38e292c1ab 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -43,7 +43,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
         stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "receipts")
         self.assertEqual(1, len(rdata_rows))
-        row = rdata_rows[0]  # type: ReceiptsStream.ReceiptsStreamRow
+        row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
         self.assertEqual("!room:blue", row.room_id)
         self.assertEqual("m.read", row.receipt_type)
         self.assertEqual(USER_ID, row.user_id)
@@ -75,7 +75,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
         self.assertEqual(token, 3)
         self.assertEqual(1, len(rdata_rows))
 
-        row = rdata_rows[0]  # type: ReceiptsStream.ReceiptsStreamRow
+        row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
         self.assertEqual("!room2:blue", row.room_id)
         self.assertEqual("m.read", row.receipt_type)
         self.assertEqual(USER_ID, row.user_id)
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index ecd360c2d0..3ff5afc6e5 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -47,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
-        row = rdata_rows[0]  # type: TypingStream.TypingStreamRow
+        row: TypingStream.TypingStreamRow = rdata_rows[0]
         self.assertEqual(ROOM_ID, row.room_id)
         self.assertEqual([USER_ID], row.user_ids)
 
@@ -102,7 +102,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
-        row = rdata_rows[0]  # type: TypingStream.TypingStreamRow
+        row: TypingStream.TypingStreamRow = rdata_rows[0]
         self.assertEqual(ROOM_ID, row.room_id)
         self.assertEqual([USER_ID], row.user_ids)
 
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 76e6644353..ffa425328f 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -31,7 +31,7 @@ from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
 
 logger = logging.getLogger(__name__)
 
-test_server_connection_factory = None  # type: Optional[TestServerTLSConnectionFactory]
+test_server_connection_factory: Optional[TestServerTLSConnectionFactory] = None
 
 
 class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
@@ -70,7 +70,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
             self.reactor,
             FakeSite(resource),
             "GET",
-            "/{}/{}".format(target, media_id),
+            f"/{target}/{media_id}",
             shorthand=False,
             access_token=self.access_token,
             await_result=False,
@@ -113,7 +113,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         self.assertEqual(request.method, b"GET")
         self.assertEqual(
             request.path,
-            "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"),
+            f"/_matrix/media/r0/download/{target}/{media_id}".encode("utf-8"),
         )
         self.assertEqual(
             request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 5eca5c165d..f3615af97e 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -211,7 +211,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
             self.reactor,
             sync_hs_site,
             "GET",
-            "/sync?since={}".format(next_batch),
+            f"/sync?since={next_batch}",
             access_token=access_token,
         )
 
@@ -241,7 +241,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
             self.reactor,
             sync_hs_site,
             "GET",
-            "/sync?since={}".format(vector_clock_token),
+            f"/sync?since={vector_clock_token}",
             access_token=access_token,
         )
 
@@ -266,7 +266,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
             self.reactor,
             sync_hs_site,
             "GET",
-            "/sync?since={}".format(next_batch),
+            f"/sync?since={next_batch}",
             access_token=access_token,
         )