summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/http/federation/well_known_resolver.py13
-rw-r--r--synapse/storage/background_updates.py16
-rw-r--r--synapse/storage/database.py14
-rw-r--r--synapse/storage/databases/main/appservice.py4
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py2
-rw-r--r--synapse/storage/databases/main/event_federation.py26
-rw-r--r--synapse/storage/databases/main/event_push_actions.py2
-rw-r--r--synapse/storage/databases/main/events.py38
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py8
-rw-r--r--synapse/storage/databases/main/events_worker.py6
-rw-r--r--synapse/storage/databases/main/purge_events.py2
-rw-r--r--synapse/storage/databases/main/push_rule.py6
-rw-r--r--synapse/storage/databases/main/registration.py2
-rw-r--r--synapse/storage/databases/main/stream.py6
-rw-r--r--synapse/storage/databases/main/tags.py2
-rw-r--r--synapse/storage/databases/main/ui_auth.py4
-rw-r--r--synapse/storage/persist_events.py16
-rw-r--r--synapse/storage/prepare_database.py6
-rw-r--r--synapse/storage/state.py4
-rw-r--r--synapse/storage/util/id_generators.py12
-rw-r--r--synapse/storage/util/sequence.py6
-rw-r--r--synapse/util/async_helpers.py8
-rw-r--r--synapse/util/batching_queue.py8
-rw-r--r--synapse/util/caches/__init__.py4
-rw-r--r--synapse/util/caches/cached_call.py6
-rw-r--r--synapse/util/caches/deferred_cache.py12
-rw-r--r--synapse/util/caches/descriptors.py36
-rw-r--r--synapse/util/caches/dictionary_cache.py6
-rw-r--r--synapse/util/caches/expiringcache.py4
-rw-r--r--synapse/util/caches/lrucache.py8
-rw-r--r--synapse/util/caches/response_cache.py2
-rw-r--r--synapse/util/caches/stream_change_cache.py6
-rw-r--r--synapse/util/caches/ttlcache.py6
-rw-r--r--synapse/util/iterutils.py2
-rw-r--r--synapse/util/macaroons.py2
-rw-r--r--synapse/util/metrics.py2
-rw-r--r--synapse/util/patch_inline_callbacks.py4
37 files changed, 149 insertions, 162 deletions
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 20d39a4ea6..43f2140429 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -70,10 +70,8 @@ WELL_KNOWN_RETRY_ATTEMPTS = 3
 logger = logging.getLogger(__name__)
 
 
-_well_known_cache = TTLCache("well-known")  # type: TTLCache[bytes, Optional[bytes]]
-_had_valid_well_known_cache = TTLCache(
-    "had-valid-well-known"
-)  # type: TTLCache[bytes, bool]
+_well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
+_had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")
 
 
 @attr.s(slots=True, frozen=True)
@@ -130,9 +128,10 @@ class WellKnownResolver:
         # requests for the same server in parallel?
         try:
             with Measure(self._clock, "get_well_known"):
-                result, cache_period = await self._fetch_well_known(
-                    server_name
-                )  # type: Optional[bytes], float
+                result: Optional[bytes]
+                cache_period: float
+
+                result, cache_period = await self._fetch_well_known(server_name)
 
         except _FetchWellKnownFailure as e:
             if prev_result and e.temporary:
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 142787fdfd..82b31d24f1 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -92,14 +92,12 @@ class BackgroundUpdater:
         self.db_pool = database
 
         # if a background update is currently running, its name.
-        self._current_background_update = None  # type: Optional[str]
-
-        self._background_update_performance = (
-            {}
-        )  # type: Dict[str, BackgroundUpdatePerformance]
-        self._background_update_handlers = (
-            {}
-        )  # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
+        self._current_background_update: Optional[str] = None
+
+        self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
+        self._background_update_handlers: Dict[
+            str, Callable[[JsonDict, int], Awaitable[int]]
+        ] = {}
         self._all_done = False
 
     def start_doing_background_updates(self) -> None:
@@ -411,7 +409,7 @@ class BackgroundUpdater:
             c.execute(sql)
 
         if isinstance(self.db_pool.engine, engines.PostgresEngine):
-            runner = create_index_psql  # type: Optional[Callable[[Connection], None]]
+            runner: Optional[Callable[[Connection], None]] = create_index_psql
         elif psql_only:
             runner = None
         else:
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 33c42cf95a..f80d822c12 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -670,8 +670,8 @@ class DatabasePool:
         Returns:
             The result of func
         """
-        after_callbacks = []  # type: List[_CallbackListEntry]
-        exception_callbacks = []  # type: List[_CallbackListEntry]
+        after_callbacks: List[_CallbackListEntry] = []
+        exception_callbacks: List[_CallbackListEntry] = []
 
         if not current_context():
             logger.warning("Starting db txn '%s' from sentinel context", desc)
@@ -1090,7 +1090,7 @@ class DatabasePool:
                 return False
 
         # We didn't find any existing rows, so insert a new one
-        allvalues = {}  # type: Dict[str, Any]
+        allvalues: Dict[str, Any] = {}
         allvalues.update(keyvalues)
         allvalues.update(values)
         allvalues.update(insertion_values)
@@ -1121,7 +1121,7 @@ class DatabasePool:
             values: The nonunique columns and their new values
             insertion_values: additional key/values to use only when inserting
         """
-        allvalues = {}  # type: Dict[str, Any]
+        allvalues: Dict[str, Any] = {}
         allvalues.update(keyvalues)
         allvalues.update(insertion_values or {})
 
@@ -1257,7 +1257,7 @@ class DatabasePool:
             value_values: A list of each row's value column values.
                 Ignored if value_names is empty.
         """
-        allnames = []  # type: List[str]
+        allnames: List[str] = []
         allnames.extend(key_names)
         allnames.extend(value_names)
 
@@ -1566,7 +1566,7 @@ class DatabasePool:
         """
         keyvalues = keyvalues or {}
 
-        results = []  # type: List[Dict[str, Any]]
+        results: List[Dict[str, Any]] = []
 
         if not iterable:
             return results
@@ -1978,7 +1978,7 @@ class DatabasePool:
             raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
 
         where_clause = "WHERE " if filters or keyvalues or exclude_keyvalues else ""
-        arg_list = []  # type: List[Any]
+        arg_list: List[Any] = []
         if filters:
             where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
             arg_list += list(filters.values())
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 9f182c2a89..e2d1b758bd 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -48,9 +48,7 @@ def _make_exclusive_regex(
     ]
     if exclusive_user_regexes:
         exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
-        exclusive_user_pattern = re.compile(
-            exclusive_user_regex
-        )  # type: Optional[Pattern]
+        exclusive_user_pattern: Optional[Pattern] = re.compile(exclusive_user_regex)
     else:
         # We handle this case specially otherwise the constructed regex
         # will always match
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 0e3dd4e9ca..78ae68ec68 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -247,7 +247,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
         txn.execute(sql, query_params)
 
-        result = {}  # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
+        result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
         for (user_id, device_id, display_name, key_json) in txn:
             if include_deleted_devices:
                 deleted_devices.remove((user_id, device_id))
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4e06938849..d39368c20e 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -62,9 +62,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             )
 
         # Cache of event ID to list of auth event IDs and their depths.
-        self._event_auth_cache = LruCache(
+        self._event_auth_cache: LruCache[str, List[Tuple[str, int]]] = LruCache(
             500000, "_event_auth_cache", size_callback=len
-        )  # type: LruCache[str, List[Tuple[str, int]]]
+        )
 
         self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
 
@@ -137,10 +137,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         initial_events = set(event_ids)
 
         # All the events that we've found that are reachable from the events.
-        seen_events = set()  # type: Set[str]
+        seen_events: Set[str] = set()
 
         # A map from chain ID to max sequence number of the given events.
-        event_chains = {}  # type: Dict[int, int]
+        event_chains: Dict[int, int] = {}
 
         sql = """
             SELECT event_id, chain_id, sequence_number
@@ -182,7 +182,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         """
 
         # A map from chain ID to max sequence number *reachable* from any event ID.
-        chains = {}  # type: Dict[int, int]
+        chains: Dict[int, int] = {}
 
         # Add all linked chains reachable from initial set of chains.
         for batch in batch_iter(event_chains, 1000):
@@ -353,14 +353,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         initial_events = set(state_sets[0]).union(*state_sets[1:])
 
         # Map from event_id -> (chain ID, seq no)
-        chain_info = {}  # type: Dict[str, Tuple[int, int]]
+        chain_info: Dict[str, Tuple[int, int]] = {}
 
         # Map from chain ID -> seq no -> event Id
-        chain_to_event = {}  # type: Dict[int, Dict[int, str]]
+        chain_to_event: Dict[int, Dict[int, str]] = {}
 
         # All the chains that we've found that are reachable from the state
         # sets.
-        seen_chains = set()  # type: Set[int]
+        seen_chains: Set[int] = set()
 
         sql = """
             SELECT event_id, chain_id, sequence_number
@@ -392,9 +392,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         # Corresponds to `state_sets`, except as a map from chain ID to max
         # sequence number reachable from the state set.
-        set_to_chain = []  # type: List[Dict[int, int]]
+        set_to_chain: List[Dict[int, int]] = []
         for state_set in state_sets:
-            chains = {}  # type: Dict[int, int]
+            chains: Dict[int, int] = {}
             set_to_chain.append(chains)
 
             for event_id in state_set:
@@ -446,7 +446,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         # Mapping from chain ID to the range of sequence numbers that should be
         # pulled from the database.
-        chain_to_gap = {}  # type: Dict[int, Tuple[int, int]]
+        chain_to_gap: Dict[int, Tuple[int, int]] = {}
 
         for chain_id in seen_chains:
             min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
@@ -555,7 +555,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         }
 
         # The sorted list of events whose auth chains we should walk.
-        search = []  # type: List[Tuple[int, str]]
+        search: List[Tuple[int, str]] = []
 
         # We need to get the depth of the initial events for sorting purposes.
         sql = """
@@ -578,7 +578,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         search.sort()
 
         # Map from event to its auth events
-        event_to_auth_events = {}  # type: Dict[str, Set[str]]
+        event_to_auth_events: Dict[str, Set[str]] = {}
 
         base_sql = """
             SELECT a.event_id, auth_id, depth
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index d1237c65cc..55caa6bbe7 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -759,7 +759,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # object because we might not have the same amount of rows in each of them. To do
         # this, we use a dict indexed on the user ID and room ID to make it easier to
         # populate.
-        summaries = {}  # type: Dict[Tuple[str, str], _EventPushSummary]
+        summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
         for row in txn:
             summaries[(row[0], row[1])] = _EventPushSummary(
                 unread_count=row[2],
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 08c580b0dc..ec8579b9ad 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -109,10 +109,8 @@ class PersistEventsStore:
 
         # Ideally we'd move these ID gens here, unfortunately some other ID
         # generators are chained off them so doing so is a bit of a PITA.
-        self._backfill_id_gen = (
-            self.store._backfill_id_gen
-        )  # type: MultiWriterIdGenerator
-        self._stream_id_gen = self.store._stream_id_gen  # type: MultiWriterIdGenerator
+        self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
+        self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
 
         # This should only exist on instances that are configured to write
         assert (
@@ -221,7 +219,7 @@ class PersistEventsStore:
         Returns:
             Filtered event ids
         """
-        results = []  # type: List[str]
+        results: List[str] = []
 
         def _get_events_which_are_prevs_txn(txn, batch):
             sql = """
@@ -508,7 +506,7 @@ class PersistEventsStore:
         """
 
         # Map from event ID to chain ID/sequence number.
-        chain_map = {}  # type: Dict[str, Tuple[int, int]]
+        chain_map: Dict[str, Tuple[int, int]] = {}
 
         # Set of event IDs to calculate chain ID/seq numbers for.
         events_to_calc_chain_id_for = set(event_to_room_id)
@@ -817,8 +815,8 @@ class PersistEventsStore:
         #      new chain if the sequence number has already been allocated.
         #
 
-        existing_chains = set()  # type: Set[int]
-        tree = []  # type: List[Tuple[str, Optional[str]]]
+        existing_chains: Set[int] = set()
+        tree: List[Tuple[str, Optional[str]]] = []
 
         # We need to do this in a topologically sorted order as we want to
         # generate chain IDs/sequence numbers of an event's auth events before
@@ -848,7 +846,7 @@ class PersistEventsStore:
         )
         txn.execute(sql % (clause,), args)
 
-        chain_to_max_seq_no = {row[0]: row[1] for row in txn}  # type: Dict[Any, int]
+        chain_to_max_seq_no: Dict[Any, int] = {row[0]: row[1] for row in txn}
 
         # Allocate the new events chain ID/sequence numbers.
         #
@@ -858,8 +856,8 @@ class PersistEventsStore:
         # number of new chain IDs in one call, replacing all temporary
         # objects with real allocated chain IDs.
 
-        unallocated_chain_ids = set()  # type: Set[object]
-        new_chain_tuples = {}  # type: Dict[str, Tuple[Any, int]]
+        unallocated_chain_ids: Set[object] = set()
+        new_chain_tuples: Dict[str, Tuple[Any, int]] = {}
         for event_id, auth_event_id in tree:
             # If we reference an auth_event_id we fetch the allocated chain ID,
             # either from the existing `chain_map` or the newly generated
@@ -870,7 +868,7 @@ class PersistEventsStore:
                 if not existing_chain_id:
                     existing_chain_id = chain_map[auth_event_id]
 
-            new_chain_tuple = None  # type: Optional[Tuple[Any, int]]
+            new_chain_tuple: Optional[Tuple[Any, int]] = None
             if existing_chain_id:
                 # We found a chain ID/sequence number candidate, check its
                 # not already taken.
@@ -897,9 +895,9 @@ class PersistEventsStore:
         )
 
         # Map from potentially temporary chain ID to real chain ID
-        chain_id_to_allocated_map = dict(
+        chain_id_to_allocated_map: Dict[Any, int] = dict(
             zip(unallocated_chain_ids, newly_allocated_chain_ids)
-        )  # type: Dict[Any, int]
+        )
         chain_id_to_allocated_map.update((c, c) for c in existing_chains)
 
         return {
@@ -1175,9 +1173,9 @@ class PersistEventsStore:
         Returns:
             list[(EventBase, EventContext)]: filtered list
         """
-        new_events_and_contexts = (
-            OrderedDict()
-        )  # type: OrderedDict[str, Tuple[EventBase, EventContext]]
+        new_events_and_contexts: OrderedDict[
+            str, Tuple[EventBase, EventContext]
+        ] = OrderedDict()
         for event, context in events_and_contexts:
             prev_event_context = new_events_and_contexts.get(event.event_id)
             if prev_event_context:
@@ -1205,7 +1203,7 @@ class PersistEventsStore:
                 we are persisting
             backfilled (bool): True if the events were backfilled
         """
-        depth_updates = {}  # type: Dict[str, int]
+        depth_updates: Dict[str, int] = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
             txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@@ -1885,7 +1883,7 @@ class PersistEventsStore:
                 ),
             )
 
-            room_to_event_ids = {}  # type: Dict[str, List[str]]
+            room_to_event_ids: Dict[str, List[str]] = {}
             for e, _ in events_and_contexts:
                 room_to_event_ids.setdefault(e.room_id, []).append(e.event_id)
 
@@ -2012,7 +2010,7 @@ class PersistEventsStore:
 
         Forward extremities are handled when we first start persisting the events.
         """
-        events_by_room = {}  # type: Dict[str, List[EventBase]]
+        events_by_room: Dict[str, List[EventBase]] = {}
         for ev in events:
             events_by_room.setdefault(ev.room_id, []).append(ev)
 
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 29f33bac55..6fcb2b8353 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -960,9 +960,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         event_to_types = {row[0]: (row[1], row[2]) for row in rows}
 
         # Calculate the new last position we've processed up to.
-        new_last_depth = rows[-1][3] if rows else last_depth  # type: int
-        new_last_stream = rows[-1][4] if rows else last_stream  # type: int
-        new_last_room_id = rows[-1][5] if rows else ""  # type: str
+        new_last_depth: int = rows[-1][3] if rows else last_depth
+        new_last_stream: int = rows[-1][4] if rows else last_stream
+        new_last_room_id: str = rows[-1][5] if rows else ""
 
         # Map from room_id to last depth/stream_ordering processed for the room,
         # excluding the last room (which we're likely still processing). We also
@@ -989,7 +989,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             retcols=("event_id", "auth_id"),
         )
 
-        event_to_auth_chain = {}  # type: Dict[str, List[str]]
+        event_to_auth_chain: Dict[str, List[str]] = {}
         for row in auth_events:
             event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 403a5ddaba..3c86adab56 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1365,10 +1365,10 @@ class EventsWorkerStore(SQLBaseStore):
         # we need to make sure that, for every stream id in the results, we get *all*
         # the rows with that stream id.
 
-        rows = await self.db_pool.runInteraction(
+        rows: List[Tuple] = await self.db_pool.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
-        )  # type: List[Tuple]
+        )
 
         # if we've got fewer rows than the limit, we're good
         if len(rows) < target_row_count:
@@ -1469,7 +1469,7 @@ class EventsWorkerStore(SQLBaseStore):
         """
 
         mapping = {}
-        txn_id_to_event = {}  # type: Dict[Tuple[str, int, str], str]
+        txn_id_to_event: Dict[Tuple[str, int, str], str] = {}
 
         for event in events:
             token_id = getattr(event.internal_metadata, "token_id", None)
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index eb4841830d..664c65dac5 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -115,7 +115,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         logger.info("[purge] looking for events to delete")
 
         should_delete_expr = "state_key IS NULL"
-        should_delete_params = ()  # type: Tuple[Any, ...]
+        should_delete_params: Tuple[Any, ...] = ()
         if not delete_local_events:
             should_delete_expr += " AND event_id NOT LIKE ?"
 
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index db52176337..a7fb8cd848 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -79,9 +79,9 @@ class PushRulesWorkerStore(
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
-            self._push_rules_stream_id_gen = StreamIdGenerator(
-                db_conn, "push_rules_stream", "stream_id"
-            )  # type: Union[StreamIdGenerator, SlavedIdTracker]
+            self._push_rules_stream_id_gen: Union[
+                StreamIdGenerator, SlavedIdTracker
+            ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
         else:
             self._push_rules_stream_id_gen = SlavedIdTracker(
                 db_conn, "push_rules_stream", "stream_id"
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e31c5864ac..6ad1a0cf7f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1744,7 +1744,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
             items = keyvalues.items()
             where_clause = " AND ".join(k + " = ?" for k, _ in items)
-            values = [v for _, v in items]  # type: List[Union[str, int]]
+            values: List[Union[str, int]] = [v for _, v in items]
             # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
             # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
             # clause and values before we handle that. This seems to be only used in the "set password" handler.
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 7581c7d3ff..959f13de47 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1085,9 +1085,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
         # then filtering the results.
         if from_token.topological is not None:
-            from_bound = (
-                from_token.as_historical_tuple()
-            )  # type: Tuple[Optional[int], int]
+            from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple()
         elif direction == "b":
             from_bound = (
                 None,
@@ -1099,7 +1097,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
                 from_token.stream,
             )
 
-        to_bound = None  # type: Optional[Tuple[Optional[int], int]]
+        to_bound: Optional[Tuple[Optional[int], int]] = None
         if to_token:
             if to_token.topological is not None:
                 to_bound = to_token.as_historical_tuple()
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 1d62c6140f..f93ff0a545 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -42,7 +42,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
         )
 
-        tags_by_room = {}  # type: Dict[str, Dict[str, JsonDict]]
+        tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
         for row in rows:
             room_tags = tags_by_room.setdefault(row["room_id"], {})
             room_tags[row["tag"]] = db_to_json(row["content"])
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 22c05cdde7..38bfdf5dad 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -224,12 +224,12 @@ class UIAuthWorkerStore(SQLBaseStore):
         self, txn: LoggingTransaction, session_id: str, key: str, value: Any
     ):
         # Get the current value.
-        result = self.db_pool.simple_select_one_txn(
+        result: Dict[str, Any] = self.db_pool.simple_select_one_txn(  # type: ignore
             txn,
             table="ui_auth_sessions",
             keyvalues={"session_id": session_id},
             retcols=("serverdict",),
-        )  # type: Dict[str, Any]  # type: ignore
+        )
 
         # Update it and add it back to the database.
         serverdict = db_to_json(result["serverdict"])
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 051095fea9..a39877f0d5 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -307,7 +307,7 @@ class EventsPersistenceStorage:
             matched the transcation ID; the existing event is returned in such
             a case.
         """
-        partitioned = {}  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
+        partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
 
@@ -384,7 +384,7 @@ class EventsPersistenceStorage:
             A dictionary of event ID to event ID we didn't persist as we already
             had another event persisted with the same TXN ID.
         """
-        replaced_events = {}  # type: Dict[str, str]
+        replaced_events: Dict[str, str] = {}
         if not events_and_contexts:
             return replaced_events
 
@@ -440,16 +440,14 @@ class EventsPersistenceStorage:
             # Set of remote users which were in rooms the server has left. We
             # should check if we still share any rooms and if not we mark their
             # device lists as stale.
-            potentially_left_users = set()  # type: Set[str]
+            potentially_left_users: Set[str] = set()
 
             if not backfilled:
                 with Measure(self._clock, "_calculate_state_and_extrem"):
                     # Work out the new "current state" for each room.
                     # We do this by working out what the new extremities are and then
                     # calculating the state from that.
-                    events_by_room = (
-                        {}
-                    )  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
+                    events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
                     for event, context in chunk:
                         events_by_room.setdefault(event.room_id, []).append(
                             (event, context)
@@ -622,9 +620,9 @@ class EventsPersistenceStorage:
         )
 
         # Remove any events which are prev_events of any existing events.
-        existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
-            result
-        )  # type: Collection[str]
+        existing_prevs: Collection[
+            str
+        ] = await self.persist_events_store._get_events_which_are_prevs(result)
         result.difference_update(existing_prevs)
 
         # Finally handle the case where the new events have soft-failed prev
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 683e5e3b90..82a7686df0 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -256,7 +256,7 @@ def _setup_new_database(
         for database in databases
     )
 
-    directory_entries = []  # type: List[_DirectoryListing]
+    directory_entries: List[_DirectoryListing] = []
     for directory in directories:
         directory_entries.extend(
             _DirectoryListing(file_name, os.path.join(directory, file_name))
@@ -424,10 +424,10 @@ def _upgrade_existing_database(
             directories.append(os.path.join(schema_path, database, "delta", str(v)))
 
         # Used to check if we have any duplicate file names
-        file_name_counter = Counter()  # type: CounterType[str]
+        file_name_counter: CounterType[str] = Counter()
 
         # Now find which directories have anything of interest.
-        directory_entries = []  # type: List[_DirectoryListing]
+        directory_entries: List[_DirectoryListing] = []
         for directory in directories:
             logger.debug("Looking for schema deltas in %s", directory)
             try:
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index c9dce726cb..f8fbba9d38 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -91,7 +91,7 @@ class StateFilter:
         Returns:
             The new state filter.
         """
-        type_dict = {}  # type: Dict[str, Optional[Set[str]]]
+        type_dict: Dict[str, Optional[Set[str]]] = {}
         for typ, s in types:
             if typ in type_dict:
                 if type_dict[typ] is None:
@@ -194,7 +194,7 @@ class StateFilter:
         """
 
         where_clause = ""
-        where_args = []  # type: List[str]
+        where_args: List[str] = []
 
         if self.is_full():
             return where_clause, where_args
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f1e62f9e85..c768fdea56 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -112,7 +112,7 @@ class StreamIdGenerator:
         # insertion ordering will ensure its in the correct ordering.
         #
         # The key and values are the same, but we never look at the values.
-        self._unfinished_ids = OrderedDict()  # type: OrderedDict[int, int]
+        self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
 
     def get_next(self):
         """
@@ -236,15 +236,15 @@ class MultiWriterIdGenerator:
         # Note: If we are a negative stream then we still store all the IDs as
         # positive to make life easier for us, and simply negate the IDs when we
         # return them.
-        self._current_positions = {}  # type: Dict[str, int]
+        self._current_positions: Dict[str, int] = {}
 
         # Set of local IDs that we're still processing. The current position
         # should be less than the minimum of this set (if not empty).
-        self._unfinished_ids = set()  # type: Set[int]
+        self._unfinished_ids: Set[int] = set()
 
         # Set of local IDs that we've processed that are larger than the current
         # position, due to there being smaller unpersisted IDs.
-        self._finished_ids = set()  # type: Set[int]
+        self._finished_ids: Set[int] = set()
 
         # We track the max position where we know everything before has been
         # persisted. This is done by a) looking at the min across all instances
@@ -265,7 +265,7 @@ class MultiWriterIdGenerator:
         self._persisted_upto_position = (
             min(self._current_positions.values()) if self._current_positions else 1
         )
-        self._known_persisted_positions = []  # type: List[int]
+        self._known_persisted_positions: List[int] = []
 
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
@@ -465,7 +465,7 @@ class MultiWriterIdGenerator:
             self._unfinished_ids.discard(next_id)
             self._finished_ids.add(next_id)
 
-            new_cur = None  # type: Optional[int]
+            new_cur: Optional[int] = None
 
             if self._unfinished_ids:
                 # If there are unfinished IDs then the new position will be the
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 30b6b8e0ca..bb33e04fb1 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -208,10 +208,10 @@ class LocalSequenceGenerator(SequenceGenerator):
                  get_next_id_txn; should return the curreent maximum id
         """
         # the callback. this is cleared after it is called, so that it can be GCed.
-        self._callback = get_first_callback  # type: Optional[GetFirstCallbackType]
+        self._callback: Optional[GetFirstCallbackType] = get_first_callback
 
         # The current max value, or None if we haven't looked in the DB yet.
-        self._current_max_id = None  # type: Optional[int]
+        self._current_max_id: Optional[int] = None
         self._lock = threading.Lock()
 
     def get_next_id_txn(self, txn: Cursor) -> int:
@@ -274,7 +274,7 @@ def build_sequence_generator(
             `check_consistency` details.
     """
     if isinstance(database_engine, PostgresEngine):
-        seq = PostgresSequenceGenerator(sequence_name)  # type: SequenceGenerator
+        seq: SequenceGenerator = PostgresSequenceGenerator(sequence_name)
     else:
         seq = LocalSequenceGenerator(get_first_callback)
 
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 061102c3c8..014db1355b 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -257,7 +257,7 @@ class Linearizer:
             max_count: The maximum number of concurrent accesses
         """
         if name is None:
-            self.name = id(self)  # type: Union[str, int]
+            self.name: Union[str, int] = id(self)
         else:
             self.name = name
 
@@ -269,7 +269,7 @@ class Linearizer:
         self.max_count = max_count
 
         # key_to_defer is a map from the key to a _LinearizerEntry.
-        self.key_to_defer = {}  # type: Dict[Hashable, _LinearizerEntry]
+        self.key_to_defer: Dict[Hashable, _LinearizerEntry] = {}
 
     def is_queued(self, key: Hashable) -> bool:
         """Checks whether there is a process queued up waiting"""
@@ -409,10 +409,10 @@ class ReadWriteLock:
 
     def __init__(self):
         # Latest readers queued
-        self.key_to_current_readers = {}  # type: Dict[str, Set[defer.Deferred]]
+        self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {}
 
         # Latest writer queued
-        self.key_to_current_writer = {}  # type: Dict[str, defer.Deferred]
+        self.key_to_current_writer: Dict[str, defer.Deferred] = {}
 
     async def read(self, key: str) -> ContextManager:
         new_defer = defer.Deferred()
diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py
index 8fd5bfb69b..274cea7eb7 100644
--- a/synapse/util/batching_queue.py
+++ b/synapse/util/batching_queue.py
@@ -93,11 +93,11 @@ class BatchingQueue(Generic[V, R]):
         self._clock = clock
 
         # The set of keys currently being processed.
-        self._processing_keys = set()  # type: Set[Hashable]
+        self._processing_keys: Set[Hashable] = set()
 
         # The currently pending batch of values by key, with a Deferred to call
         # with the result of the corresponding `_process_batch_callback` call.
-        self._next_values = {}  # type: Dict[Hashable, List[Tuple[V, defer.Deferred]]]
+        self._next_values: Dict[Hashable, List[Tuple[V, defer.Deferred]]] = {}
 
         # The function to call with batches of values.
         self._process_batch_callback = process_batch_callback
@@ -108,9 +108,7 @@ class BatchingQueue(Generic[V, R]):
 
         number_of_keys.labels(self._name).set_function(lambda: len(self._next_values))
 
-        self._number_in_flight_metric = number_in_flight.labels(
-            self._name
-        )  # type: Gauge
+        self._number_in_flight_metric: Gauge = number_in_flight.labels(self._name)
 
     async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
         """Adds the value to the queue with the given key, returning the result
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index ca36f07c20..9012034b7a 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -29,8 +29,8 @@ logger = logging.getLogger(__name__)
 TRACK_MEMORY_USAGE = False
 
 
-caches_by_name = {}  # type: Dict[str, Sized]
-collectors_by_name = {}  # type: Dict[str, CacheMetric]
+caches_by_name: Dict[str, Sized] = {}
+collectors_by_name: Dict[str, "CacheMetric"] = {}
 
 cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
 cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
index a301c9e89b..891bee0b33 100644
--- a/synapse/util/caches/cached_call.py
+++ b/synapse/util/caches/cached_call.py
@@ -63,9 +63,9 @@ class CachedCall(Generic[TV]):
             f: The underlying function. Only one call to this function will be alive
                 at once (per instance of CachedCall)
         """
-        self._callable = f  # type: Optional[Callable[[], Awaitable[TV]]]
-        self._deferred = None  # type: Optional[Deferred]
-        self._result = None  # type: Union[None, Failure, TV]
+        self._callable: Optional[Callable[[], Awaitable[TV]]] = f
+        self._deferred: Optional[Deferred] = None
+        self._result: Union[None, Failure, TV] = None
 
     async def get(self) -> TV:
         """Kick off the call if necessary, and return the result"""
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1044139119..8c6fafc677 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -80,25 +80,25 @@ class DeferredCache(Generic[KT, VT]):
         cache_type = TreeCache if tree else dict
 
         # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
-        self._pending_deferred_cache = (
-            cache_type()
-        )  # type: Union[TreeCache, MutableMapping[KT, CacheEntry]]
+        self._pending_deferred_cache: Union[
+            TreeCache, "MutableMapping[KT, CacheEntry]"
+        ] = cache_type()
 
         def metrics_cb():
             cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
 
         # cache is used for completed results and maps to the result itself, rather than
         # a Deferred.
-        self.cache = LruCache(
+        self.cache: LruCache[KT, VT] = LruCache(
             max_size=max_entries,
             cache_name=name,
             cache_type=cache_type,
             size_callback=(lambda d: len(d) or 1) if iterable else None,
             metrics_collection_callback=metrics_cb,
             apply_cache_factor_from_config=apply_cache_factor_from_config,
-        )  # type: LruCache[KT, VT]
+        )
 
-        self.thread = None  # type: Optional[threading.Thread]
+        self.thread: Optional[threading.Thread] = None
 
     @property
     def max_entries(self):
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index d77e8edeea..1e8e6b1d01 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -46,17 +46,17 @@ F = TypeVar("F", bound=Callable[..., Any])
 
 
 class _CachedFunction(Generic[F]):
-    invalidate = None  # type: Any
-    invalidate_all = None  # type: Any
-    prefill = None  # type: Any
-    cache = None  # type: Any
-    num_args = None  # type: Any
+    invalidate: Any = None
+    invalidate_all: Any = None
+    prefill: Any = None
+    cache: Any = None
+    num_args: Any = None
 
-    __name__ = None  # type: str
+    __name__: str
 
     # Note: This function signature is actually fiddled with by the synapse mypy
     # plugin to a) make it a bound method, and b) remove any `cache_context` arg.
-    __call__ = None  # type: F
+    __call__: F
 
 
 class _CacheDescriptorBase:
@@ -115,8 +115,8 @@ class _CacheDescriptorBase:
 
 
 class _LruCachedFunction(Generic[F]):
-    cache = None  # type: LruCache[CacheKey, Any]
-    __call__ = None  # type: F
+    cache: LruCache[CacheKey, Any]
+    __call__: F
 
 
 def lru_cache(
@@ -180,10 +180,10 @@ class LruCacheDescriptor(_CacheDescriptorBase):
         self.max_entries = max_entries
 
     def __get__(self, obj, owner):
-        cache = LruCache(
+        cache: LruCache[CacheKey, Any] = LruCache(
             cache_name=self.orig.__name__,
             max_size=self.max_entries,
-        )  # type: LruCache[CacheKey, Any]
+        )
 
         get_cache_key = self.cache_key_builder
         sentinel = LruCacheDescriptor._Sentinel.sentinel
@@ -271,12 +271,12 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         self.iterable = iterable
 
     def __get__(self, obj, owner):
-        cache = DeferredCache(
+        cache: DeferredCache[CacheKey, Any] = DeferredCache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
             tree=self.tree,
             iterable=self.iterable,
-        )  # type: DeferredCache[CacheKey, Any]
+        )
 
         get_cache_key = self.cache_key_builder
 
@@ -359,7 +359,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj, objtype=None):
         cached_method = getattr(obj, self.cached_method_name)
-        cache = cached_method.cache  # type: DeferredCache[CacheKey, Any]
+        cache: DeferredCache[CacheKey, Any] = cached_method.cache
         num_args = cached_method.num_args
 
         @functools.wraps(self.orig)
@@ -472,15 +472,15 @@ class _CacheContext:
 
     Cache = Union[DeferredCache, LruCache]
 
-    _cache_context_objects = (
-        WeakValueDictionary()
-    )  # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
+    _cache_context_objects: """WeakValueDictionary[
+        Tuple["_CacheContext.Cache", CacheKey], "_CacheContext"
+    ]""" = WeakValueDictionary()
 
     def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
         self._cache = cache
         self._cache_key = cache_key
 
-    def invalidate(self):  # type: () -> None
+    def invalidate(self) -> None:
         """Invalidates the cache entry referred to by the context."""
         self._cache.invalidate(self._cache_key)
 
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 56d94d96ce..3f852edd7f 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -62,13 +62,13 @@ class DictionaryCache(Generic[KT, DKT]):
     """
 
     def __init__(self, name: str, max_entries: int = 1000):
-        self.cache = LruCache(
+        self.cache: LruCache[KT, DictionaryEntry] = LruCache(
             max_size=max_entries, cache_name=name, size_callback=len
-        )  # type: LruCache[KT, DictionaryEntry]
+        )
 
         self.name = name
         self.sequence = 0
-        self.thread = None  # type: Optional[threading.Thread]
+        self.thread: Optional[threading.Thread] = None
 
     def check_thread(self) -> None:
         expected_thread = self.thread
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index ac47a31cd7..bde16b8577 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -27,7 +27,7 @@ from synapse.util.caches import register_cache
 logger = logging.getLogger(__name__)
 
 
-SENTINEL = object()  # type: Any
+SENTINEL: Any = object()
 
 
 T = TypeVar("T")
@@ -71,7 +71,7 @@ class ExpiringCache(Generic[KT, VT]):
         self._expiry_ms = expiry_ms
         self._reset_expiry_on_get = reset_expiry_on_get
 
-        self._cache = OrderedDict()  # type: OrderedDict[KT, _CacheEntry]
+        self._cache: OrderedDict[KT, _CacheEntry] = OrderedDict()
 
         self.iterable = iterable
 
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 4b9d0433ff..efeba0cb96 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -226,7 +226,7 @@ class _Node:
         # footprint down. Storing `None` is free as its a singleton, while empty
         # lists are 56 bytes (and empty sets are 216 bytes, if we did the naive
         # thing and used sets).
-        self.callbacks = None  # type: Optional[List[Callable[[], None]]]
+        self.callbacks: Optional[List[Callable[[], None]]] = None
 
         self.add_callbacks(callbacks)
 
@@ -362,15 +362,15 @@ class LruCache(Generic[KT, VT]):
 
         # register_cache might call our "set_cache_factor" callback; there's nothing to
         # do yet when we get resized.
-        self._on_resize = None  # type: Optional[Callable[[],None]]
+        self._on_resize: Optional[Callable[[], None]] = None
 
         if cache_name is not None:
-            metrics = register_cache(
+            metrics: Optional[CacheMetric] = register_cache(
                 "lru_cache",
                 cache_name,
                 self,
                 collect_callback=metrics_collection_callback,
-            )  # type: Optional[CacheMetric]
+            )
         else:
             metrics = None
 
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 34c662c4db..ed7204336f 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -66,7 +66,7 @@ class ResponseCache(Generic[KV]):
         # This is poorly-named: it includes both complete and incomplete results.
         # We keep complete results rather than switching to absolute values because
         # that makes it easier to cache Failure results.
-        self.pending_result_cache = {}  # type: Dict[KV, ObservableDeferred]
+        self.pending_result_cache: Dict[KV, ObservableDeferred] = {}
 
         self.clock = clock
         self.timeout_sec = timeout_ms / 1000.0
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index e81e468899..3a41a8baa6 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -45,10 +45,10 @@ class StreamChangeCache:
     ):
         self._original_max_size = max_size
         self._max_size = math.floor(max_size)
-        self._entity_to_key = {}  # type: Dict[EntityType, int]
+        self._entity_to_key: Dict[EntityType, int] = {}
 
         # map from stream id to the a set of entities which changed at that stream id.
-        self._cache = SortedDict()  # type: SortedDict[int, Set[EntityType]]
+        self._cache: SortedDict[int, Set[EntityType]] = SortedDict()
 
         # the earliest stream_pos for which we can reliably answer
         # get_all_entities_changed. In other words, one less than the earliest
@@ -155,7 +155,7 @@ class StreamChangeCache:
         if stream_pos < self._earliest_known_stream_pos:
             return None
 
-        changed_entities = []  # type: List[EntityType]
+        changed_entities: List[EntityType] = []
 
         for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
             changed_entities.extend(self._cache[k])
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index c276107d56..46afe3f934 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -23,7 +23,7 @@ from synapse.util.caches import register_cache
 
 logger = logging.getLogger(__name__)
 
-SENTINEL = object()  # type: Any
+SENTINEL: Any = object()
 
 T = TypeVar("T")
 KT = TypeVar("KT")
@@ -35,10 +35,10 @@ class TTLCache(Generic[KT, VT]):
 
     def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
         # map from key to _CacheEntry
-        self._data = {}  # type: Dict[KT, _CacheEntry]
+        self._data: Dict[KT, _CacheEntry] = {}
 
         # the _CacheEntries, sorted by expiry time
-        self._expiry_list = SortedList()  # type: SortedList[_CacheEntry]
+        self._expiry_list: SortedList[_CacheEntry] = SortedList()
 
         self._timer = timer
 
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 886afa9d19..8ac3eab2f5 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -68,7 +68,7 @@ def sorted_topologically(
     # This is implemented by Kahn's algorithm.
 
     degree_map = {node: 0 for node in nodes}
-    reverse_graph = {}  # type: Dict[T, Set[T]]
+    reverse_graph: Dict[T, Set[T]] = {}
 
     for node, edges in graph.items():
         if node not in degree_map:
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
index f6ebfd7e7d..d1f76e3dc5 100644
--- a/synapse/util/macaroons.py
+++ b/synapse/util/macaroons.py
@@ -39,7 +39,7 @@ def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
              caveat in the macaroon, or if the caveat was not found in the macaroon.
     """
     prefix = key + " = "
-    result = None  # type: Optional[str]
+    result: Optional[str] = None
     for caveat in macaroon.caveats:
         if not caveat.caveat_id.startswith(prefix):
             continue
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 45353d41c5..1b82dca81b 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -124,7 +124,7 @@ class Measure:
             assert isinstance(curr_context, LoggingContext)
             parent_context = curr_context
         self._logging_context = LoggingContext(str(curr_context), parent_context)
-        self.start = None  # type: Optional[int]
+        self.start: Optional[int] = None
 
     def __enter__(self) -> "Measure":
         if self.start is not None:
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index eed0291cae..99f01e325c 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -41,7 +41,7 @@ def do_patch():
         @functools.wraps(f)
         def wrapped(*args, **kwargs):
             start_context = current_context()
-            changes = []  # type: List[str]
+            changes: List[str] = []
             orig = orig_inline_callbacks(_check_yield_points(f, changes))
 
             try:
@@ -131,7 +131,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
         gen = f(*args, **kwargs)
 
         last_yield_line_no = gen.gi_frame.f_lineno
-        result = None  # type: Any
+        result: Any = None
         while True:
             expected_context = current_context()