summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorJonathan de Jong <jonathan@automatia.nl>2021-07-13 12:52:58 +0200
committerGitHub <noreply@github.com>2021-07-13 11:52:58 +0100
commit93729719b8451493e1df9930feb9f02f14ea5cef (patch)
tree90f9608894f30f35c824427aa8b3657a41246bdd /tests
parent[pyupgrade] `tests/` (#10347) (diff)
downloadsynapse-93729719b8451493e1df9930feb9f02f14ea5cef.tar.xz
Use inline type hints in `tests/` (#10350)
This PR is tantamount to running:

    python3.8 -m com2ann -v 6 tests/

(com2ann requires python 3.8 to run)
Diffstat (limited to 'tests')
-rw-r--r--tests/events/test_presence_router.py6
-rw-r--r--tests/module_api/test_api.py16
-rw-r--r--tests/replication/_base.py12
-rw-r--r--tests/replication/tcp/streams/test_events.py14
-rw-r--r--tests/replication/tcp/streams/test_receipts.py4
-rw-r--r--tests/replication/tcp/streams/test_typing.py4
-rw-r--r--tests/replication/test_multi_media_repo.py2
-rw-r--r--tests/rest/client/test_third_party_rules.py4
-rw-r--r--tests/rest/client/v1/test_login.py14
-rw-r--r--tests/server.py8
-rw-r--r--tests/storage/test_background_update.py4
-rw-r--r--tests/storage/test_id_generators.py6
-rw-r--r--tests/test_state.py2
-rw-r--r--tests/test_utils/html_parsers.py6
-rw-r--r--tests/unittest.py2
-rw-r--r--tests/util/caches/test_descriptors.py2
-rw-r--r--tests/util/test_itertools.py18
17 files changed, 61 insertions, 63 deletions
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 875b0d0a11..c4ad33194d 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -152,7 +152,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
         )
         self.assertEqual(len(presence_updates), 1)
 
-        presence_update = presence_updates[0]  # type: UserPresenceState
+        presence_update: UserPresenceState = presence_updates[0]
         self.assertEqual(presence_update.user_id, self.other_user_one_id)
         self.assertEqual(presence_update.state, "online")
         self.assertEqual(presence_update.status_msg, "boop")
@@ -274,7 +274,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
         presence_updates, _ = sync_presence(self, self.other_user_id)
         self.assertEqual(len(presence_updates), 1)
 
-        presence_update = presence_updates[0]  # type: UserPresenceState
+        presence_update: UserPresenceState = presence_updates[0]
         self.assertEqual(presence_update.user_id, self.other_user_id)
         self.assertEqual(presence_update.state, "online")
         self.assertEqual(presence_update.status_msg, "I'm online!")
@@ -320,7 +320,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
         )
         for call in calls:
             call_args = call[0]
-            federation_transaction = call_args[0]  # type: Transaction
+            federation_transaction: Transaction = call_args[0]
 
             # Get the sent EDUs in this transaction
             edus = federation_transaction.get_dict()["edus"]
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 2c68b9a13c..81d9e2f484 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -100,9 +100,9 @@ class ModuleApiTestCase(HomeserverTestCase):
             "content": content,
             "sender": user_id,
         }
-        event = self.get_success(
+        event: EventBase = self.get_success(
             self.module_api.create_and_send_event_into_room(event_dict)
-        )  # type: EventBase
+        )
         self.assertEqual(event.sender, user_id)
         self.assertEqual(event.type, "m.room.message")
         self.assertEqual(event.room_id, room_id)
@@ -136,9 +136,9 @@ class ModuleApiTestCase(HomeserverTestCase):
             "sender": user_id,
             "state_key": "",
         }
-        event = self.get_success(
+        event: EventBase = self.get_success(
             self.module_api.create_and_send_event_into_room(event_dict)
-        )  # type: EventBase
+        )
         self.assertEqual(event.sender, user_id)
         self.assertEqual(event.type, "m.room.power_levels")
         self.assertEqual(event.room_id, room_id)
@@ -281,7 +281,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         )
         for call in calls:
             call_args = call[0]
-            federation_transaction = call_args[0]  # type: Transaction
+            federation_transaction: Transaction = call_args[0]
 
             # Get the sent EDUs in this transaction
             edus = federation_transaction.get_dict()["edus"]
@@ -390,7 +390,7 @@ def _test_sending_local_online_presence_to_local_user(
     )
     test_case.assertEqual(len(presence_updates), 1)
 
-    presence_update = presence_updates[0]  # type: UserPresenceState
+    presence_update: UserPresenceState = presence_updates[0]
     test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
     test_case.assertEqual(presence_update.state, "online")
 
@@ -443,7 +443,7 @@ def _test_sending_local_online_presence_to_local_user(
     )
     test_case.assertEqual(len(presence_updates), 1)
 
-    presence_update = presence_updates[0]  # type: UserPresenceState
+    presence_update: UserPresenceState = presence_updates[0]
     test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
     test_case.assertEqual(presence_update.state, "online")
 
@@ -454,7 +454,7 @@ def _test_sending_local_online_presence_to_local_user(
     )
     test_case.assertEqual(len(presence_updates), 1)
 
-    presence_update = presence_updates[0]  # type: UserPresenceState
+    presence_update: UserPresenceState = presence_updates[0]
     test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
     test_case.assertEqual(presence_update.state, "online")
 
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 386ea70a25..e9fd991718 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -53,9 +53,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
         self.streamer = hs.get_replication_streamer()
-        self.server = server_factory.buildProtocol(
+        self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
             None
-        )  # type: ServerReplicationStreamProtocol
+        )
 
         # Make a new HomeServer object for the worker
         self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -195,7 +195,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         fetching updates for given stream.
         """
 
-        path = request.path  # type: bytes  # type: ignore
+        path: bytes = request.path  # type: ignore
         self.assertRegex(
             path,
             br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
@@ -212,7 +212,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
     unlike `BaseStreamTestCase`.
     """
 
-    servlets = []  # type: List[Callable[[HomeServer, JsonResource], None]]
+    servlets: List[Callable[[HomeServer, JsonResource], None]] = []
 
     def setUp(self):
         super().setUp()
@@ -448,7 +448,7 @@ class TestReplicationDataHandler(ReplicationDataHandler):
         super().__init__(hs)
 
         # list of received (stream_name, token, row) tuples
-        self.received_rdata_rows = []  # type: List[Tuple[str, int, Any]]
+        self.received_rdata_rows: List[Tuple[str, int, Any]] = []
 
     async def on_rdata(self, stream_name, instance_name, token, rows):
         await super().on_rdata(stream_name, instance_name, token, rows)
@@ -484,7 +484,7 @@ class FakeRedisPubSubServer:
 class FakeRedisPubSubProtocol(Protocol):
     """A connection from a client talking to the fake Redis server."""
 
-    transport = None  # type: Optional[FakeTransport]
+    transport: Optional[FakeTransport] = None
 
     def __init__(self, server: FakeRedisPubSubServer):
         self._server = server
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index f51fa0a79e..666008425a 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -135,9 +135,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point = self.get_success(
+        fork_point: List[str] = self.get_success(
             self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
-        )  # type: List[str]
+        )
 
         events = [
             self._inject_state_event(sender=OTHER_USER)
@@ -238,7 +238,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.assertEqual(row.data.event_id, pl_event.event_id)
 
         # the state rows are unsorted
-        state_rows = []  # type: List[EventsStreamCurrentStateRow]
+        state_rows: List[EventsStreamCurrentStateRow] = []
         for stream_name, _, row in received_rows:
             self.assertEqual("events", stream_name)
             self.assertIsInstance(row, EventsStreamRow)
@@ -290,11 +290,11 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point = self.get_success(
+        fork_point: List[str] = self.get_success(
             self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
-        )  # type: List[str]
+        )
 
-        events = []  # type: List[EventBase]
+        events: List[EventBase] = []
         for user in user_ids:
             events.extend(
                 self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
@@ -355,7 +355,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
             self.assertEqual(row.data.event_id, pl_events[i].event_id)
 
             # the state rows are unsorted
-            state_rows = []  # type: List[EventsStreamCurrentStateRow]
+            state_rows: List[EventsStreamCurrentStateRow] = []
             for _ in range(STATES_PER_USER + 1):
                 stream_name, token, row = received_rows.pop(0)
                 self.assertEqual("events", stream_name)
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index 7f5d932f0b..38e292c1ab 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -43,7 +43,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
         stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "receipts")
         self.assertEqual(1, len(rdata_rows))
-        row = rdata_rows[0]  # type: ReceiptsStream.ReceiptsStreamRow
+        row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
         self.assertEqual("!room:blue", row.room_id)
         self.assertEqual("m.read", row.receipt_type)
         self.assertEqual(USER_ID, row.user_id)
@@ -75,7 +75,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
         self.assertEqual(token, 3)
         self.assertEqual(1, len(rdata_rows))
 
-        row = rdata_rows[0]  # type: ReceiptsStream.ReceiptsStreamRow
+        row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
         self.assertEqual("!room2:blue", row.room_id)
         self.assertEqual("m.read", row.receipt_type)
         self.assertEqual(USER_ID, row.user_id)
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index ecd360c2d0..3ff5afc6e5 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -47,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
-        row = rdata_rows[0]  # type: TypingStream.TypingStreamRow
+        row: TypingStream.TypingStreamRow = rdata_rows[0]
         self.assertEqual(ROOM_ID, row.room_id)
         self.assertEqual([USER_ID], row.user_ids)
 
@@ -102,7 +102,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
-        row = rdata_rows[0]  # type: TypingStream.TypingStreamRow
+        row: TypingStream.TypingStreamRow = rdata_rows[0]
         self.assertEqual(ROOM_ID, row.room_id)
         self.assertEqual([USER_ID], row.user_ids)
 
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index b42f1288eb..ffa425328f 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -31,7 +31,7 @@ from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
 
 logger = logging.getLogger(__name__)
 
-test_server_connection_factory = None  # type: Optional[TestServerTLSConnectionFactory]
+test_server_connection_factory: Optional[TestServerTLSConnectionFactory] = None
 
 
 class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index e1fe72fc5d..c5e1c5458b 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -233,11 +233,11 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
             "content": content,
             "sender": self.user_id,
         }
-        event = self.get_success(
+        event: EventBase = self.get_success(
             current_rules_module().module_api.create_and_send_event_into_room(
                 event_dict
             )
-        )  # type: EventBase
+        )
 
         self.assertEquals(event.sender, self.user_id)
         self.assertEquals(event.room_id, self.room_id)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 605b952316..7eba69642a 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -453,7 +453,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
 
         # stick the flows results in a dict by type
-        flow_results = {}  # type: Dict[str, Any]
+        flow_results: Dict[str, Any] = {}
         for f in channel.json_body["flows"]:
             flow_type = f["type"]
             self.assertNotIn(
@@ -501,7 +501,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         p.close()
 
         # there should be a link for each href
-        returned_idps = []  # type: List[str]
+        returned_idps: List[str] = []
         for link in p.links:
             path, query = link.split("?", 1)
             self.assertEqual(path, "pick_idp")
@@ -582,7 +582,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         # ... and should have set a cookie including the redirect url
         cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
         assert cookie_headers
-        cookies = {}  # type: Dict[str, str]
+        cookies: Dict[str, str] = {}
         for h in cookie_headers:
             key, value = h.split(";")[0].split("=", maxsplit=1)
             cookies[key] = value
@@ -874,9 +874,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
     def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
-        result = jwt.encode(
-            payload, secret, self.jwt_algorithm
-        )  # type: Union[str, bytes]
+        result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm)
         if isinstance(result, bytes):
             return result.decode("ascii")
         return result
@@ -1084,7 +1082,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
 
     def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
-        result = jwt.encode(payload, secret, "RS256")  # type: Union[bytes,str]
+        result: Union[bytes, str] = jwt.encode(payload, secret, "RS256")
         if isinstance(result, bytes):
             return result.decode("ascii")
         return result
@@ -1272,7 +1270,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
 
         # ... with a username_mapping_session cookie
-        cookies = {}  # type: Dict[str,str]
+        cookies: Dict[str, str] = {}
         channel.extract_cookies(cookies)
         self.assertIn("username_mapping_session", cookies)
         session_id = cookies["username_mapping_session"]
diff --git a/tests/server.py b/tests/server.py
index f32d8dc375..6fddd3b305 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -52,7 +52,7 @@ class FakeChannel:
     _reactor = attr.ib()
     result = attr.ib(type=dict, default=attr.Factory(dict))
     _ip = attr.ib(type=str, default="127.0.0.1")
-    _producer = None  # type: Optional[Union[IPullProducer, IPushProducer]]
+    _producer: Optional[Union[IPullProducer, IPushProducer]] = None
 
     @property
     def json_body(self):
@@ -316,8 +316,10 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 
         self._tcp_callbacks = {}
         self._udp = []
-        lookups = self.lookups = {}  # type: Dict[str, str]
-        self._thread_callbacks = deque()  # type: Deque[Callable[[], None]]
+        self.lookups: Dict[str, str] = {}
+        self._thread_callbacks: Deque[Callable[[], None]] = deque()
+
+        lookups = self.lookups
 
         @implementer(IResolverSimple)
         class FakeResolver:
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 069db0edc4..0da42b5ac5 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -7,9 +7,7 @@ from tests import unittest
 
 class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
-        self.updates = (
-            self.hs.get_datastore().db_pool.updates
-        )  # type: BackgroundUpdater
+        self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates
         # the base test class should have run the real bg updates for us
         self.assertTrue(
             self.get_success(self.updates.has_completed_background_updates())
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 792b1c44c1..7486078284 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -27,7 +27,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
-        self.db_pool = self.store.db_pool  # type: DatabasePool
+        self.db_pool: DatabasePool = self.store.db_pool
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
@@ -460,7 +460,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
-        self.db_pool = self.store.db_pool  # type: DatabasePool
+        self.db_pool: DatabasePool = self.store.db_pool
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
@@ -586,7 +586,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
-        self.db_pool = self.store.db_pool  # type: DatabasePool
+        self.db_pool: DatabasePool = self.store.db_pool
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
diff --git a/tests/test_state.py b/tests/test_state.py
index 62f7095873..780eba823c 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -199,7 +199,7 @@ class StateTestCase(unittest.TestCase):
 
         self.store.register_events(graph.walk())
 
-        context_store = {}  # type: dict[str, EventContext]
+        context_store: dict[str, EventContext] = {}
 
         for event in graph.walk():
             context = yield defer.ensureDeferred(
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
index 1fbb38f4be..e878af5f12 100644
--- a/tests/test_utils/html_parsers.py
+++ b/tests/test_utils/html_parsers.py
@@ -23,13 +23,13 @@ class TestHtmlParser(HTMLParser):
         super().__init__()
 
         # a list of links found in the doc
-        self.links = []  # type: List[str]
+        self.links: List[str] = []
 
         # the values of any hidden <input>s: map from name to value
-        self.hiddens = {}  # type: Dict[str, Optional[str]]
+        self.hiddens: Dict[str, Optional[str]] = {}
 
         # the values of any radio buttons: map from name to list of values
-        self.radios = {}  # type: Dict[str, List[Optional[str]]]
+        self.radios: Dict[str, List[Optional[str]]] = {}
 
     def handle_starttag(
         self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
diff --git a/tests/unittest.py b/tests/unittest.py
index 907b94b10a..c6d9064423 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -520,7 +520,7 @@ class HomeserverTestCase(TestCase):
         if not isinstance(deferred, Deferred):
             return d
 
-        results = []  # type: list
+        results: list = []
         deferred.addBoth(results.append)
 
         self.pump(by=by)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 0277998cbe..39947a166b 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -174,7 +174,7 @@ class DescriptorTestCase(unittest.TestCase):
                 return self.result
 
         obj = Cls()
-        callbacks = set()  # type: Set[str]
+        callbacks: Set[str] = set()
 
         # set off an asynchronous request
         obj.result = origin_d = defer.Deferred()
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index e712eb42ea..3c0ddd4f18 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -44,7 +44,7 @@ class ChunkSeqTests(TestCase):
         )
 
     def test_empty_input(self):
-        parts = chunk_seq([], 5)  # type: Iterable[Sequence]
+        parts: Iterable[Sequence] = chunk_seq([], 5)
 
         self.assertEqual(
             list(parts),
@@ -56,13 +56,13 @@ class SortTopologically(TestCase):
     def test_empty(self):
         "Test that an empty graph works correctly"
 
-        graph = {}  # type: Dict[int, List[int]]
+        graph: Dict[int, List[int]] = {}
         self.assertEqual(list(sorted_topologically([], graph)), [])
 
     def test_handle_empty_graph(self):
         "Test that a graph where a node doesn't have an entry is treated as empty"
 
-        graph = {}  # type: Dict[int, List[int]]
+        graph: Dict[int, List[int]] = {}
 
         # For disconnected nodes the output is simply sorted.
         self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
@@ -70,7 +70,7 @@ class SortTopologically(TestCase):
     def test_disconnected(self):
         "Test that a graph with no edges work"
 
-        graph = {1: [], 2: []}  # type: Dict[int, List[int]]
+        graph: Dict[int, List[int]] = {1: [], 2: []}
 
         # For disconnected nodes the output is simply sorted.
         self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
@@ -78,19 +78,19 @@ class SortTopologically(TestCase):
     def test_linear(self):
         "Test that a simple `4 -> 3 -> 2 -> 1` graph works"
 
-        graph = {1: [], 2: [1], 3: [2], 4: [3]}  # type: Dict[int, List[int]]
+        graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
 
         self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
 
     def test_subset(self):
         "Test that only sorting a subset of the graph works"
-        graph = {1: [], 2: [1], 3: [2], 4: [3]}  # type: Dict[int, List[int]]
+        graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
 
         self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
 
     def test_fork(self):
         "Test that a forked graph works"
-        graph = {1: [], 2: [1], 3: [1], 4: [2, 3]}  # type: Dict[int, List[int]]
+        graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]}
 
         # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
         # always get the same one.
@@ -98,12 +98,12 @@ class SortTopologically(TestCase):
 
     def test_duplicates(self):
         "Test that a graph with duplicate edges work"
-        graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]}  # type: Dict[int, List[int]]
+        graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]}
 
         self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
 
     def test_multiple_paths(self):
         "Test that a graph with multiple paths between two nodes work"
-        graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]}  # type: Dict[int, List[int]]
+        graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]}
 
         self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])