diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 82e9ef02d2..6d45a8a9f6 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -36,7 +36,17 @@ what sort order was used:
"""
import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ cast,
+)
import attr
from frozendict import frozendict
@@ -732,7 +742,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
A tuple of (stream ordering, topological ordering, event_id)
"""
- def _f(txn):
+ def _f(txn: LoggingTransaction) -> Optional[Tuple[int, int, str]]:
sql = (
"SELECT stream_ordering, topological_ordering, event_id"
" FROM events"
@@ -742,7 +752,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" LIMIT 1"
)
txn.execute(sql, (room_id, stream_ordering))
- return txn.fetchone()
+ return cast(Optional[Tuple[int, int, str]], txn.fetchone())
return await self.db_pool.runInteraction(
"get_room_event_before_stream_ordering", _f
@@ -839,7 +849,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@staticmethod
def _set_before_and_after(
events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
- ):
+ ) -> None:
"""Inserts ordering information to events' internal metadata from
the DB rows.
@@ -985,7 +995,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
the `current_id`).
"""
- def get_all_new_events_stream_txn(txn):
+ def get_all_new_events_stream_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[int, List[str]]:
sql = (
"SELECT e.stream_ordering, e.event_id"
" FROM events AS e"
@@ -1331,7 +1343,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_id_for_instance(self, instance_name: str) -> int:
"""Get a unique, immutable ID that corresponds to the given Synapse worker instance."""
- def _get_id_for_instance_txn(txn):
+ def _get_id_for_instance_txn(txn: LoggingTransaction) -> int:
instance_id = self.db_pool.simple_select_one_onecol_txn(
txn,
table="instance_map",
|