summary refs log tree commit diff
diff options
context:
space:
mode:
authorSean Quah <8349537+squahtx@users.noreply.github.com>2021-11-26 18:41:31 +0000
committerGitHub <noreply@github.com>2021-11-26 18:41:31 +0000
commitffd858aa68239aeaf06591d94c0ab1b3c185440f (patch)
treed8802699acbc78790551e6c232f12650b65cab75
parentSupport expiry of refresh tokens and expiry of the overall session when refre... (diff)
downloadsynapse-ffd858aa68239aeaf06591d94c0ab1b3c185440f.tar.xz
Add type hints to `synapse/storage/databases/main/events_worker.py` (#11411)
Also refactor the stream ID trackers/generators a bit and try to
document them better.
-rw-r--r--changelog.d/11411.misc1
-rw-r--r--mypy.ini4
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py22
-rw-r--r--synapse/replication/slave/storage/push_rule.py4
-rw-r--r--synapse/replication/tcp/streams/events.py6
-rw-r--r--synapse/state/__init__.py2
-rw-r--r--synapse/state/v1.py3
-rw-r--r--synapse/storage/_base.py4
-rw-r--r--synapse/storage/databases/main/events.py29
-rw-r--r--synapse/storage/databases/main/events_worker.py218
-rw-r--r--synapse/storage/databases/main/push_rule.py11
-rw-r--r--synapse/storage/util/id_generators.py116
-rw-r--r--tests/replication/test_sharded_event_persister.py6
13 files changed, 255 insertions, 171 deletions
diff --git a/changelog.d/11411.misc b/changelog.d/11411.misc
new file mode 100644
index 0000000000..86594a332d
--- /dev/null
+++ b/changelog.d/11411.misc
@@ -0,0 +1 @@
+Add type hints to storage classes.
diff --git a/mypy.ini b/mypy.ini
index bc4f59154d..eb3976e74c 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -33,7 +33,6 @@ exclude = (?x)
    |synapse/storage/databases/main/event_federation.py
    |synapse/storage/databases/main/event_push_actions.py
    |synapse/storage/databases/main/events_bg_updates.py
-   |synapse/storage/databases/main/events_worker.py
    |synapse/storage/databases/main/group_server.py
    |synapse/storage/databases/main/metrics.py
    |synapse/storage/databases/main/monthly_active_users.py
@@ -184,6 +183,9 @@ disallow_untyped_defs = True
 [mypy-synapse.storage.databases.main.directory]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.events_worker]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.room_batch]
 disallow_untyped_defs = True
 
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 8c1bf9227a..fa132d10b4 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -14,10 +14,18 @@
 from typing import List, Optional, Tuple
 
 from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.util.id_generators import _load_current_id
+from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id
 
 
-class SlavedIdTracker:
+class SlavedIdTracker(AbstractStreamIdTracker):
+    """Tracks the "current" stream ID of a stream with a single writer.
+
+    See `AbstractStreamIdTracker` for more details.
+
+    Note that this class does not work correctly when there are multiple
+    writers.
+    """
+
     def __init__(
         self,
         db_conn: LoggingDatabaseConnection,
@@ -36,17 +44,7 @@ class SlavedIdTracker:
         self._current = (max if self.step > 0 else min)(self._current, new_id)
 
     def get_current_token(self) -> int:
-        """
-
-        Returns:
-            int
-        """
         return self._current
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer.
-
-        For streams with single writers this is equivalent to
-        `get_current_token`.
-        """
         return self.get_current_token()
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 4d5f862862..7541e21de9 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 
@@ -25,9 +24,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
         return self._push_rules_stream_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        # We assert this for the benefit of mypy
-        assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
-
         if stream_name == PushRulesStream.NAME:
             self._push_rules_stream_id_gen.advance(instance_name, token)
             for row in rows:
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index a030e9299e..a390cfcb74 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import heapq
 from collections.abc import Iterable
-from typing import TYPE_CHECKING, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Optional, Tuple, Type
 
 import attr
 
@@ -157,7 +157,7 @@ class EventsStream(Stream):
 
         # now we fetch up to that many rows from the events table
 
-        event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
+        event_rows = await self._store.get_all_new_forward_event_rows(
             instance_name, from_token, current_token, target_row_count
         )
 
@@ -191,7 +191,7 @@ class EventsStream(Stream):
         # finally, fetch the ex-outliers rows. We assume there are few enough of these
         # not to bother with the limit.
 
-        ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
+        ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
             instance_name, from_token, upper_limit
         )
 
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1605411b00..446204dbe5 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -764,7 +764,7 @@ class StateResolutionStore:
     store: "DataStore"
 
     def get_events(
-        self, event_ids: Iterable[str], allow_rejected: bool = False
+        self, event_ids: Collection[str], allow_rejected: bool = False
     ) -> Awaitable[Dict[str, EventBase]]:
         """Get events from the database
 
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 6edadea550..499a328201 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -17,6 +17,7 @@ import logging
 from typing import (
     Awaitable,
     Callable,
+    Collection,
     Dict,
     Iterable,
     List,
@@ -44,7 +45,7 @@ async def resolve_events_with_store(
     room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
-    state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
+    state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
 ) -> StateMap[str]:
     """
     Args:
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 0623da9aa1..3056e64ff5 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -21,7 +21,7 @@ from synapse.storage.database import LoggingTransaction  # noqa: F401
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
 from synapse.storage.database import DatabasePool
 from synapse.storage.types import Connection
-from synapse.types import StreamToken, get_domain_from_id
+from synapse.types import get_domain_from_id
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -48,7 +48,7 @@ class SQLBaseStore(metaclass=ABCMeta):
         self,
         stream_name: str,
         instance_name: str,
-        token: StreamToken,
+        token: int,
         rows: Iterable[Any],
     ) -> None:
         pass
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 06832221ad..c3440de2cb 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 import itertools
 import logging
-from collections import OrderedDict, namedtuple
+from collections import OrderedDict
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -41,9 +41,10 @@ from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.logging.utils import log_function
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.events_worker import EventCacheEntry
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder
@@ -64,9 +65,6 @@ event_counter = Counter(
 )
 
 
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-
-
 @attr.s(slots=True)
 class DeltaState:
     """Deltas to use to update the `current_state_events` table.
@@ -108,16 +106,21 @@ class PersistEventsStore:
         self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
 
-        # Ideally we'd move these ID gens here, unfortunately some other ID
-        # generators are chained off them so doing so is a bit of a PITA.
-        self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
-        self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
-
         # This should only exist on instances that are configured to write
         assert (
             hs.get_instance_name() in hs.config.worker.writers.events
         ), "Can only instantiate EventsStore on master"
 
+        # Since we have been configured to write, we ought to have id generators,
+        # rather than id trackers.
+        assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
+        assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
+
+        # Ideally we'd move these ID gens here, unfortunately some other ID
+        # generators are chained off them so doing so is a bit of a PITA.
+        self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
+        self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+
     async def _persist_events_and_state_updates(
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
@@ -1553,11 +1556,13 @@ class PersistEventsStore:
         for row in rows:
             event = ev_map[row["event_id"]]
             if not row["rejects"] and not row["redacts"]:
-                to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
+                to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
         def prefill():
             for cache_entry in to_prefill:
-                self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
+                self.store._get_event_cache.set(
+                    (cache_entry.event.event_id,), cache_entry
+                )
 
         txn.call_after(prefill)
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c6bcfe1c32..4cefc0a07e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -15,14 +15,18 @@
 import logging
 import threading
 from typing import (
+    TYPE_CHECKING,
+    Any,
     Collection,
     Container,
     Dict,
     Iterable,
     List,
+    NoReturn,
     Optional,
     Set,
     Tuple,
+    cast,
     overload,
 )
 
@@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
     EventFormatVersions,
+    RoomVersion,
     RoomVersions,
 )
 from synapse.events import EventBase, make_event_from_dict
@@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdTracker,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import unwrapFirstError
@@ -69,6 +82,9 @@ from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(
 
 
 @attr.s(slots=True, auto_attribs=True)
-class _EventCacheEntry:
+class EventCacheEntry:
     event: EventBase
     redacted_event: Optional[EventBase]
 
@@ -129,7 +145,7 @@ class _EventRow:
     json: str
     internal_metadata: str
     format_version: Optional[int]
-    room_version_id: Optional[int]
+    room_version_id: Optional[str]
     rejected_reason: Optional[str]
     redactions: List[str]
     outlier: bool
@@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
     # options controlling this.
     USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
+        self._stream_id_gen: AbstractStreamIdTracker
+        self._backfill_id_gen: AbstractStreamIdTracker
         if isinstance(database.engine, PostgresEngine):
             # If we're using Postgres than we can use `MultiWriterIdGenerator`
             # regardless of whether this process writes to the streams or not.
@@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
                 5 * 60 * 1000,
             )
 
-        self._get_event_cache = LruCache(
+        self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
             cache_name="*getEvent*",
             max_size=hs.config.caches.event_cache_size,
         )
@@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
         # ID to cache entry. Note that the returned dict may not have the
         # requested event in it if the event isn't in the DB.
         self._current_event_fetches: Dict[
-            str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+            str, ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = {}
 
         self._event_fetch_lock = threading.Condition()
-        self._event_fetch_list = []
+        self._event_fetch_list: List[
+            Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
+        ] = []
         self._event_fetch_ongoing = 0
         event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
 
         # We define this sequence here so that it can be referenced from both
         # the DataStore and PersistEventStore.
-        def get_chain_id_txn(txn):
+        def get_chain_id_txn(txn: Cursor) -> int:
             txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
-            return txn.fetchone()[0]
+            return cast(Tuple[int], txn.fetchone())[0]
 
         self.event_chain_id_gen = build_sequence_generator(
             db_conn,
@@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
             id_column="chain_id",
         )
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
         if stream_name == EventsStream.NAME:
             self._stream_id_gen.advance(instance_name, token)
         elif stream_name == BackfillStream.NAME:
@@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
-        get_prev_content: bool = False,
-        allow_rejected: bool = False,
-        allow_none: Literal[False] = False,
-        check_room_id: Optional[str] = None,
+        get_prev_content: bool = ...,
+        allow_rejected: bool = ...,
+        allow_none: Literal[False] = ...,
+        check_room_id: Optional[str] = ...,
     ) -> EventBase:
         ...
 
@@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore):
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
-        get_prev_content: bool = False,
-        allow_rejected: bool = False,
-        allow_none: Literal[True] = False,
-        check_room_id: Optional[str] = None,
+        get_prev_content: bool = ...,
+        allow_rejected: bool = ...,
+        allow_none: Literal[True] = ...,
+        check_room_id: Optional[str] = ...,
     ) -> Optional[EventBase]:
         ...
 
@@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_events(
         self,
-        event_ids: Iterable[str],
+        event_ids: Collection[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
@@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def _get_events_from_cache_or_db(
         self, event_ids: Iterable[str], allow_rejected: bool = False
-    ) -> Dict[str, _EventCacheEntry]:
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch a bunch of events from the cache or the database.
 
         If events are pulled from the database, they will be cached for future lookups.
@@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
         # same dict into itself N times).
         already_fetching_ids: Set[str] = set()
         already_fetching_deferreds: Set[
-            ObservableDeferred[Dict[str, _EventCacheEntry]]
+            ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = set()
 
         for event_id in missing_events_ids:
@@ -601,7 +632,7 @@ class EventsWorkerStore(SQLBaseStore):
             # function returning more events than requested, but that can happen
             # already due to `_get_events_from_db`).
             fetching_deferred: ObservableDeferred[
-                Dict[str, _EventCacheEntry]
+                Dict[str, EventCacheEntry]
             ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
             for event_id in missing_events_ids:
                 self._current_event_fetches[event_id] = fetching_deferred
@@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_entry_map
 
-    def _invalidate_get_event_cache(self, event_id):
+    def _invalidate_get_event_cache(self, event_id: str) -> None:
         self._get_event_cache.invalidate((event_id,))
 
     def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
-    ) -> Dict[str, _EventCacheEntry]:
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch events from the caches.
 
         May return rejected events.
@@ -820,7 +851,7 @@ class EventsWorkerStore(SQLBaseStore):
                     for _, deferred in event_fetches_to_fail:
                         deferred.errback(exc)
 
-    def _fetch_loop(self, conn: Connection) -> None:
+    def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
         """Takes a database connection and waits for requests for events from
         the _event_fetch_list queue.
         """
@@ -850,7 +881,9 @@ class EventsWorkerStore(SQLBaseStore):
             self._fetch_event_list(conn, event_list)
 
     def _fetch_event_list(
-        self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+        self,
+        conn: LoggingDatabaseConnection,
+        event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
     ) -> None:
         """Handle a load of requests from the _event_fetch_list queue
 
@@ -877,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore):
                 )
 
                 # We only want to resolve deferreds from the main thread
-                def fire():
+                def fire() -> None:
                     for _, d in event_list:
                         d.callback(row_dict)
 
@@ -887,16 +920,16 @@ class EventsWorkerStore(SQLBaseStore):
                 logger.exception("do_fetch")
 
                 # We only want to resolve deferreds from the main thread
-                def fire(evs, exc):
-                    for _, d in evs:
+                def fire_errback(exc: Exception) -> None:
+                    for _, d in event_list:
                         d.errback(exc)
 
                 with PreserveLoggingContext():
-                    self.hs.get_reactor().callFromThread(fire, event_list, e)
+                    self.hs.get_reactor().callFromThread(fire_errback, e)
 
     async def _get_events_from_db(
-        self, event_ids: Iterable[str]
-    ) -> Dict[str, _EventCacheEntry]:
+        self, event_ids: Collection[str]
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch a bunch of events from the database.
 
         May return rejected events.
@@ -912,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore):
             map from event id to result. May return extra events which
             weren't asked for.
         """
-        fetched_events = {}
+        fetched_event_ids: Set[str] = set()
+        fetched_events: Dict[str, _EventRow] = {}
         events_to_fetch = event_ids
 
         while events_to_fetch:
             row_map = await self._enqueue_events(events_to_fetch)
 
             # we need to recursively fetch any redactions of those events
-            redaction_ids = set()
+            redaction_ids: Set[str] = set()
             for event_id in events_to_fetch:
                 row = row_map.get(event_id)
-                fetched_events[event_id] = row
+                fetched_event_ids.add(event_id)
                 if row:
+                    fetched_events[event_id] = row
                     redaction_ids.update(row.redactions)
 
-            events_to_fetch = redaction_ids.difference(fetched_events.keys())
+            events_to_fetch = redaction_ids.difference(fetched_event_ids)
             if events_to_fetch:
                 logger.debug("Also fetching redaction events %s", events_to_fetch)
 
         # build a map from event_id to EventBase
-        event_map = {}
+        event_map: Dict[str, EventBase] = {}
         for event_id, row in fetched_events.items():
-            if not row:
-                continue
             assert row.event_id == event_id
 
             rejected_reason = row.rejected_reason
@@ -962,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             room_version_id = row.room_version_id
 
+            room_version: Optional[RoomVersion]
             if not room_version_id:
                 # this should only happen for out-of-band membership events which
                 # arrived before #6983 landed. For all other events, we should have
@@ -1032,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         # finally, we can decide whether each one needs redacting, and build
         # the cache entries.
-        result_map = {}
+        result_map: Dict[str, EventCacheEntry] = {}
         for event_id, original_ev in event_map.items():
             redactions = fetched_events[event_id].redactions
             redacted_event = self._maybe_redact_event_row(
                 original_ev, redactions, event_map
             )
 
-            cache_entry = _EventCacheEntry(
+            cache_entry = EventCacheEntry(
                 event=original_ev, redacted_event=redacted_event
             )
 
@@ -1048,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return result_map
 
-    async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
+    async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.
@@ -1061,7 +1095,7 @@ class EventsWorkerStore(SQLBaseStore):
             that weren't requested.
         """
 
-        events_d = defer.Deferred()
+        events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
         with self._event_fetch_lock:
             self._event_fetch_list.append((events, events_d))
             self._event_fetch_lock.notify()
@@ -1216,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore):
         # no valid redaction found for this event
         return None
 
-    async def have_events_in_timeline(self, event_ids):
+    async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.
         """
@@ -1245,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore):
             event_ids: events we are looking for
 
         Returns:
-            set[str]: The events we have already seen.
+            The set of events we have already seen.
         """
         res = await self._have_seen_events_dict(
             (room_id, event_id) for event_id in event_ids
@@ -1268,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore):
         }
         results = {x: True for x in cache_results}
 
-        def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
+        def have_seen_events_txn(
+            txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
+        ) -> None:
             # we deliberately do *not* query the database for room_id, to make the
             # query an index-only lookup on `events_event_id_key`.
             #
@@ -1294,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore):
         return results
 
     @cached(max_entries=100000, tree=True)
-    async def have_seen_event(self, room_id: str, event_id: str):
+    async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
         # this only exists for the benefit of the @cachedList descriptor on
         # _have_seen_events_dict
         raise NotImplementedError()
 
-    def _get_current_state_event_counts_txn(self, txn, room_id):
+    def _get_current_state_event_counts_txn(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> int:
         """
         See get_current_state_event_counts.
         """
@@ -1324,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore):
             room_id,
         )
 
-    async def get_room_complexity(self, room_id):
+    async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
         """
         Get a rough approximation of the complexity of the room. This is used by
         remote servers to decide whether they wish to join the room or not.
@@ -1332,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore):
         more resources.
 
         Args:
-            room_id (str)
+            room_id: The room ID to query.
 
         Returns:
-            dict[str:int] of complexity version to complexity.
+            dict[str:float] of complexity version to complexity.
         """
         state_events = await self.get_current_state_event_counts(room_id)
 
@@ -1345,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {"v1": complexity_v1}
 
-    def get_current_events_token(self):
+    def get_current_events_token(self) -> int:
         """The current maximum token that events have reached"""
         return self._stream_id_gen.get_current_token()
 
     async def get_all_new_forward_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> List[Tuple]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
         """Returns new events, for the Events replication stream
 
         Args:
@@ -1365,7 +1403,9 @@ class EventsWorkerStore(SQLBaseStore):
             EventsStreamRow.
         """
 
-        def get_all_new_forward_event_rows(txn):
+        def get_all_new_forward_event_rows(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
             sql = (
                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
                 " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
@@ -1381,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore):
                 " LIMIT ?"
             )
             txn.execute(sql, (last_id, current_id, instance_name, limit))
-            return txn.fetchall()
+            return cast(
+                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+            )
 
         return await self.db_pool.runInteraction(
             "get_all_new_forward_event_rows", get_all_new_forward_event_rows
@@ -1389,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_ex_outlier_stream_rows(
         self, instance_name: str, last_id: int, current_id: int
-    ) -> List[Tuple]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
         """Returns de-outliered events, for the Events replication stream
 
         Args:
@@ -1402,7 +1444,9 @@ class EventsWorkerStore(SQLBaseStore):
             EventsStreamRow.
         """
 
-        def get_ex_outlier_stream_rows_txn(txn):
+        def get_ex_outlier_stream_rows_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
             sql = (
                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
                 " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
@@ -1420,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore):
             )
 
             txn.execute(sql, (last_id, current_id, instance_name))
-            return txn.fetchall()
+            return cast(
+                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+            )
 
         return await self.db_pool.runInteraction(
             "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
@@ -1428,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_all_new_backfill_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, list]], int, bool]:
+    ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
         """Get updates for backfill replication stream, including all new
         backfilled events and events that have gone from being outliers to not.
 
@@ -1456,7 +1502,9 @@ class EventsWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_new_backfill_event_rows(txn):
+        def get_all_new_backfill_event_rows(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
             sql = (
                 "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
                 " state_key, redacts, relates_to_id"
@@ -1470,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore):
                 " LIMIT ?"
             )
             txn.execute(sql, (-last_id, -current_id, instance_name, limit))
-            new_event_updates = [(row[0], row[1:]) for row in txn]
+            new_event_updates: List[
+                Tuple[int, Tuple[str, str, str, str, str, str]]
+            ] = []
+            row: Tuple[int, str, str, str, str, str, str]
+            # Type safety: iterating over `txn` yields `Tuple`, i.e.
+            # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+            # variadic tuple to a fixed length tuple and flags it up as an error.
+            for row in txn:  # type: ignore[assignment]
+                new_event_updates.append((row[0], row[1:]))
 
             limited = False
             if len(new_event_updates) == limit:
@@ -1493,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore):
                 " ORDER BY event_stream_ordering DESC"
             )
             txn.execute(sql, (-last_id, -upper_bound, instance_name))
-            new_event_updates.extend((row[0], row[1:]) for row in txn)
+            # Type safety: iterating over `txn` yields `Tuple`, i.e.
+            # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+            # variadic tuple to a fixed length tuple and flags it up as an error.
+            for row in txn:  # type: ignore[assignment]
+                new_event_updates.append((row[0], row[1:]))
 
             if len(new_event_updates) >= limit:
                 upper_bound = new_event_updates[-1][0]
@@ -1507,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_all_updated_current_state_deltas(
         self, instance_name: str, from_token: int, to_token: int, target_row_count: int
-    ) -> Tuple[List[Tuple], int, bool]:
+    ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
         """Fetch updates from current_state_delta_stream
 
         Args:
@@ -1527,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore):
                * `limited` is whether there are more updates to fetch.
         """
 
-        def get_all_updated_current_state_deltas_txn(txn):
+        def get_all_updated_current_state_deltas_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str]]:
             sql = """
                 SELECT stream_id, room_id, type, state_key, event_id
                 FROM current_state_delta_stream
@@ -1536,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore):
                 ORDER BY stream_id ASC LIMIT ?
             """
             txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
 
-        def get_deltas_for_stream_id_txn(txn, stream_id):
+        def get_deltas_for_stream_id_txn(
+            txn: LoggingTransaction, stream_id: int
+        ) -> List[Tuple[int, str, str, str, str]]:
             sql = """
                 SELECT stream_id, room_id, type, state_key, event_id
                 FROM current_state_delta_stream
                 WHERE stream_id = ?
             """
             txn.execute(sql, [stream_id])
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
 
         # we need to make sure that, for every stream id in the results, we get *all*
         # the rows with that stream id.
 
-        rows: List[Tuple] = await self.db_pool.runInteraction(
+        rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
         )
@@ -1579,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         return rows, to_token, True
 
-    async def is_event_after(self, event_id1, event_id2):
+    async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
         """Returns True if event_id1 is after event_id2 in the stream"""
         to_1, so_1 = await self.get_event_ordering(event_id1)
         to_2, so_2 = await self.get_event_ordering(event_id2)
         return (to_1, so_1) > (to_2, so_2)
 
     @cached(max_entries=5000)
-    async def get_event_ordering(self, event_id):
+    async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
         res = await self.db_pool.simple_select_one(
             table="events",
             retcols=["topological_ordering", "stream_ordering"],
@@ -1609,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore):
             None otherwise.
         """
 
-        def get_next_event_to_expire_txn(txn):
+        def get_next_event_to_expire_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Tuple[str, int]]:
             txn.execute(
                 """
                 SELECT event_id, expiry_ts FROM event_expiry
@@ -1617,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore):
                 """
             )
 
-            return txn.fetchone()
+            return cast(Optional[Tuple[str, int]], txn.fetchone())
 
         return await self.db_pool.runInteraction(
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
@@ -1681,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore):
         return mapping
 
     @wrap_as_background_process("_cleanup_old_transaction_ids")
-    async def _cleanup_old_transaction_ids(self):
+    async def _cleanup_old_transaction_ids(self) -> None:
         """Cleans out transaction id mappings older than 24hrs."""
 
-        def _cleanup_old_transaction_ids_txn(txn):
+        def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
             sql = """
                 DELETE FROM event_txn_id
                 WHERE inserted_ts < ?
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index fa782023d4..3b63267395 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,7 +28,10 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdTracker,
+    StreamIdGenerator,
+)
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -82,9 +85,9 @@ class PushRulesWorkerStore(
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
-            self._push_rules_stream_id_gen: Union[
-                StreamIdGenerator, SlavedIdTracker
-            ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
+            self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+                db_conn, "push_rules_stream", "stream_id"
+            )
         else:
             self._push_rules_stream_id_gen = SlavedIdTracker(
                 db_conn, "push_rules_stream", "stream_id"
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index ac56bc9a05..4ff3013908 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -89,31 +89,77 @@ def _load_current_id(
     return (max if step > 0 else min)(current_id, step)
 
 
-class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
-    @abc.abstractmethod
-    def get_next(self) -> AsyncContextManager[int]:
-        raise NotImplementedError()
+class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
+    """Tracks the "current" stream ID of a stream that may have multiple writers.
+
+    Stream IDs are monotonically increasing or decreasing integers representing write
+    transactions. The "current" stream ID is the stream ID such that all transactions
+    with equal or smaller stream IDs have completed. Since transactions may complete out
+    of order, this is not the same as the stream ID of the last completed transaction.
+
+    Completed transactions include both committed transactions and transactions that
+    have been rolled back.
+    """
 
     @abc.abstractmethod
-    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+    def advance(self, instance_name: str, new_id: int) -> None:
+        """Advance the position of the named writer to the given ID, if greater
+        than existing entry.
+        """
         raise NotImplementedError()
 
     @abc.abstractmethod
     def get_current_token(self) -> int:
+        """Returns the maximum stream id such that all stream ids less than or
+        equal to it have been successfully persisted.
+
+        Returns:
+            The maximum stream id.
+        """
         raise NotImplementedError()
 
     @abc.abstractmethod
     def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+
+        For streams with single writers this is equivalent to `get_current_token`.
+        """
+        raise NotImplementedError()
+
+
+class AbstractStreamIdGenerator(AbstractStreamIdTracker):
+    """Generates stream IDs for a stream that may have multiple writers.
+
+    Each stream ID represents a write transaction, whose completion is tracked
+    so that the "current" stream ID of the stream can be determined.
+
+    See `AbstractStreamIdTracker` for more details.
+    """
+
+    @abc.abstractmethod
+    def get_next(self) -> AsyncContextManager[int]:
+        """
+        Usage:
+            async with stream_id_gen.get_next() as stream_id:
+                # ... persist event ...
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+        """
+        Usage:
+            async with stream_id_gen.get_next(n) as stream_ids:
+                # ... persist events ...
+        """
         raise NotImplementedError()
 
 
 class StreamIdGenerator(AbstractStreamIdGenerator):
-    """Used to generate new stream ids when persisting events while keeping
-    track of which transactions have been completed.
+    """Generates and tracks stream IDs for a stream with a single writer.
 
-    This allows us to get the "current" stream id, i.e. the stream id such that
-    all ids less than or equal to it have completed. This handles the fact that
-    persistence of events can complete out of order.
+    This class must only be used when the current Synapse process is the sole
+    writer for a stream.
 
     Args:
         db_conn(connection):  A database connection to use to fetch the
@@ -157,12 +203,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         # The key and values are the same, but we never look at the values.
         self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
 
+    def advance(self, instance_name: str, new_id: int) -> None:
+        # `StreamIdGenerator` should only be used when there is a single writer,
+        # so replication should never happen.
+        raise Exception("Replication is not supported by StreamIdGenerator")
+
     def get_next(self) -> AsyncContextManager[int]:
-        """
-        Usage:
-            async with stream_id_gen.get_next() as stream_id:
-                # ... persist event ...
-        """
         with self._lock:
             self._current += self._step
             next_id = self._current
@@ -180,11 +226,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         return _AsyncCtxManagerWrapper(manager())
 
     def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
-        """
-        Usage:
-            async with stream_id_gen.get_next(n) as stream_ids:
-                # ... persist events ...
-        """
         with self._lock:
             next_ids = range(
                 self._current + self._step,
@@ -208,12 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         return _AsyncCtxManagerWrapper(manager())
 
     def get_current_token(self) -> int:
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
-
-        Returns:
-            The maximum stream id.
-        """
         with self._lock:
             if self._unfinished_ids:
                 return next(iter(self._unfinished_ids)) - self._step
@@ -221,16 +256,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
             return self._current
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer.
-
-        For streams with single writers this is equivalent to
-        `get_current_token`.
-        """
         return self.get_current_token()
 
 
 class MultiWriterIdGenerator(AbstractStreamIdGenerator):
-    """An ID generator that tracks a stream that can have multiple writers.
+    """Generates and tracks stream IDs for a stream with multiple writers.
 
     Uses a Postgres sequence to coordinate ID assignment, but positions of other
     writers will only get updated when `advance` is called (by replication).
@@ -475,12 +505,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         return stream_ids
 
     def get_next(self) -> AsyncContextManager[int]:
-        """
-        Usage:
-            async with stream_id_gen.get_next() as stream_id:
-                # ... persist event ...
-        """
-
         # If we have a list of instances that are allowed to write to this
         # stream, make sure we're in it.
         if self._writers and self._instance_name not in self._writers:
@@ -492,12 +516,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
 
     def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
-        """
-        Usage:
-            async with stream_id_gen.get_next_mult(5) as stream_ids:
-                # ... persist events ...
-        """
-
         # If we have a list of instances that are allowed to write to this
         # stream, make sure we're in it.
         if self._writers and self._instance_name not in self._writers:
@@ -597,15 +615,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             self._add_persisted_position(next_id)
 
     def get_current_token(self) -> int:
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
-        """
-
         return self.get_persisted_upto_position()
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer."""
-
         # If we don't have an entry for the given instance name, we assume it's a
         # new writer.
         #
@@ -631,10 +643,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             }
 
     def advance(self, instance_name: str, new_id: int) -> None:
-        """Advance the position of the named writer to the given ID, if greater
-        than existing entry.
-        """
-
         new_id *= self._return_factor
 
         with self._lock:
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 0a6e4795ee..596ba5a0c9 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -17,6 +17,7 @@ from unittest.mock import patch
 from synapse.api.room_versions import RoomVersion
 from synapse.rest import admin
 from synapse.rest.client import login, room, sync
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.server import make_request
@@ -193,7 +194,10 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         #
         # Worker2's event stream position will not advance until we call
         # __aexit__ again.
-        actx = worker_hs2.get_datastore()._stream_id_gen.get_next()
+        worker_store2 = worker_hs2.get_datastore()
+        assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator)
+
+        actx = worker_store2._stream_id_gen.get_next()
         self.get_success(actx.__aenter__())
 
         response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token)