summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/background_updates.py19
-rw-r--r--synapse/storage/database.py44
-rw-r--r--synapse/storage/databases/main/cache.py8
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py4
-rw-r--r--synapse/storage/databases/main/events.py236
-rw-r--r--synapse/storage/databases/main/events_worker.py35
-rw-r--r--synapse/storage/databases/main/metrics.py24
-rw-r--r--synapse/storage/databases/main/purge_events.py3
-rw-r--r--synapse/storage/databases/main/relations.py6
-rw-r--r--synapse/storage/databases/main/search.py33
-rw-r--r--synapse/storage/databases/main/stream.py34
-rw-r--r--synapse/storage/engines/__init__.py12
-rw-r--r--synapse/storage/engines/_base.py26
-rw-r--r--synapse/storage/engines/postgres.py92
-rw-r--r--synapse/storage/engines/sqlite.py72
-rw-r--r--synapse/storage/persist_events.py63
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/schema/__init__.py5
-rw-r--r--synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql18
-rw-r--r--synapse/storage/state.py6
-rw-r--r--synapse/storage/types.py80
21 files changed, 495 insertions, 327 deletions
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 08c6eabc6d..c2bbbb574e 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,20 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from types import TracebackType
 from typing import (
     TYPE_CHECKING,
+    Any,
     AsyncContextManager,
     Awaitable,
     Callable,
     Dict,
     Iterable,
+    List,
     Optional,
+    Type,
 )
 
 import attr
 
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.types import Connection
+from synapse.storage.types import Connection, Cursor
 from synapse.types import JsonDict
 from synapse.util import Clock, json_encoder
 
@@ -74,7 +78,12 @@ class _BackgroundUpdateContextManager:
 
         return self._update_duration_ms
 
-    async def __aexit__(self, *exc) -> None:
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> None:
         pass
 
 
@@ -352,7 +361,7 @@ class BackgroundUpdater:
             True if we have finished running all the background updates, otherwise False
         """
 
-        def get_background_updates_txn(txn):
+        def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]:
             txn.execute(
                 """
                 SELECT update_name, depends_on FROM background_updates
@@ -469,7 +478,7 @@ class BackgroundUpdater:
         self,
         update_name: str,
         update_handler: Callable[[JsonDict, int], Awaitable[int]],
-    ):
+    ) -> None:
         """Register a handler for doing a background update.
 
         The handler should take two arguments:
@@ -603,7 +612,7 @@ class BackgroundUpdater:
         else:
             runner = create_index_sqlite
 
-        async def updater(progress, batch_size):
+        async def updater(progress: JsonDict, batch_size: int) -> int:
             if runner is not None:
                 logger.info("Adding index %s to %s", index_name, table)
                 await self.db_pool.runWithConnection(runner)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 41f566b648..5ddb58a8a2 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -31,6 +31,7 @@ from typing import (
     List,
     Optional,
     Tuple,
+    Type,
     TypeVar,
     cast,
     overload,
@@ -41,6 +42,7 @@ from prometheus_client import Histogram
 from typing_extensions import Concatenate, Literal, ParamSpec
 
 from twisted.enterprise import adbapi
+from twisted.internet.interfaces import IReactorCore
 
 from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
@@ -92,7 +94,9 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
 
 
 def make_pool(
-    reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+    reactor: IReactorCore,
+    db_config: DatabaseConnectionConfig,
+    engine: BaseDatabaseEngine,
 ) -> adbapi.ConnectionPool:
     """Get the connection pool for the database."""
 
@@ -101,7 +105,7 @@ def make_pool(
     db_args = dict(db_config.config.get("args", {}))
     db_args.setdefault("cp_reconnect", True)
 
-    def _on_new_connection(conn):
+    def _on_new_connection(conn: Connection) -> None:
         # Ensure we have a logging context so we can correctly track queries,
         # etc.
         with LoggingContext("db.on_new_connection"):
@@ -157,7 +161,11 @@ class LoggingDatabaseConnection:
     default_txn_name: str
 
     def cursor(
-        self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
+        self,
+        *,
+        txn_name: Optional[str] = None,
+        after_callbacks: Optional[List["_CallbackListEntry"]] = None,
+        exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
     ) -> "LoggingTransaction":
         if not txn_name:
             txn_name = self.default_txn_name
@@ -183,11 +191,16 @@ class LoggingDatabaseConnection:
         self.conn.__enter__()
         return self
 
-    def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> Optional[bool]:
         return self.conn.__exit__(exc_type, exc_value, traceback)
 
     # Proxy through any unknown lookups to the DB conn class.
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         return getattr(self.conn, name)
 
 
@@ -391,17 +404,22 @@ class LoggingTransaction:
     def __enter__(self) -> "LoggingTransaction":
         return self
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> None:
         self.close()
 
 
 class PerformanceCounters:
-    def __init__(self):
-        self.current_counters = {}
-        self.previous_counters = {}
+    def __init__(self) -> None:
+        self.current_counters: Dict[str, Tuple[int, float]] = {}
+        self.previous_counters: Dict[str, Tuple[int, float]] = {}
 
     def update(self, key: str, duration_secs: float) -> None:
-        count, cum_time = self.current_counters.get(key, (0, 0))
+        count, cum_time = self.current_counters.get(key, (0, 0.0))
         count += 1
         cum_time += duration_secs
         self.current_counters[key] = (count, cum_time)
@@ -527,7 +545,7 @@ class DatabasePool:
     def start_profiling(self) -> None:
         self._previous_loop_ts = monotonic_time()
 
-        def loop():
+        def loop() -> None:
             curr = self._current_txn_total_time
             prev = self._previous_txn_total_time
             self._previous_txn_total_time = curr
@@ -1186,7 +1204,7 @@ class DatabasePool:
         if lock:
             self.engine.lock_table(txn, table)
 
-        def _getwhere(key):
+        def _getwhere(key: str) -> str:
             # If the value we're passing in is None (aka NULL), we need to use
             # IS, not =, as NULL = NULL equals NULL (False).
             if keyvalues[key] is None:
@@ -2258,7 +2276,7 @@ class DatabasePool:
         term: Optional[str],
         col: str,
         retcols: Collection[str],
-        desc="simple_search_list",
+        desc: str = "simple_search_list",
     ) -> Optional[List[Dict[str, Any]]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index dd4e83a2ad..1653a6a9b6 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -57,6 +57,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         self._instance_name = hs.get_instance_name()
 
+        self.db_pool.updates.register_background_index_update(
+            update_name="cache_invalidation_index_by_instance",
+            index_name="cache_invalidation_stream_by_instance_instance_index",
+            table="cache_invalidation_stream_by_instance",
+            columns=("instance_name", "stream_id"),
+            psql_only=True,  # The table is only on postgres DBs.
+        )
+
     async def get_all_updated_caches(
         self, instance_name: str, last_id: int, current_id: int, limit: int
     ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index b789a588a5..af59be6b48 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -21,7 +21,7 @@ from synapse.api.errors import StoreError
 from synapse.logging.opentracing import log_kv, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import LoggingTransaction
-from synapse.types import JsonDict, JsonSerializable
+from synapse.types import JsonDict, JsonSerializable, StreamKeyType
 from synapse.util import json_encoder
 
 
@@ -126,7 +126,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                     "message": "Set room key",
                     "room_id": room_id,
                     "session_id": session_id,
-                    "room_key": room_key,
+                    StreamKeyType.ROOM: room_key,
                 }
             )
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ed29a0a5e2..42d484dc98 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -36,9 +36,8 @@ from prometheus_client import Counter
 import synapse.metrics
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.room_versions import RoomVersions
-from synapse.crypto.event_signing import compute_event_reference_hash
-from synapse.events import EventBase  # noqa: F401
-from synapse.events.snapshot import EventContext  # noqa: F401
+from synapse.events import EventBase, relation_from_event
+from synapse.events.snapshot import EventContext
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -50,7 +49,7 @@ from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.engines.postgres import PostgresEngine
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
 
@@ -129,7 +128,6 @@ class PersistEventsStore:
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         *,
-        current_state_for_room: Dict[str, StateMap[str]],
         state_delta_for_room: Dict[str, DeltaState],
         new_forward_extremities: Dict[str, Set[str]],
         use_negative_stream_ordering: bool = False,
@@ -140,8 +138,6 @@ class PersistEventsStore:
 
         Args:
             events_and_contexts:
-            current_state_for_room: Map from room_id to the current state of
-                the room based on forward extremities
             state_delta_for_room: Map from room_id to the delta to apply to
                 room state
             new_forward_extremities: Map from room_id to set of event IDs
@@ -216,9 +212,6 @@ class PersistEventsStore:
 
                 event_counter.labels(event.type, origin_type, origin_entity).inc()
 
-            for room_id, new_state in current_state_for_room.items():
-                self.store.get_current_state_ids.prefill((room_id,), new_state)
-
             for room_id, latest_event_ids in new_forward_extremities.items():
                 self.store.get_latest_event_ids_in_room.prefill(
                     (room_id,), list(latest_event_ids)
@@ -236,7 +229,9 @@ class PersistEventsStore:
         """
         results: List[str] = []
 
-        def _get_events_which_are_prevs_txn(txn, batch):
+        def _get_events_which_are_prevs_txn(
+            txn: LoggingTransaction, batch: Collection[str]
+        ) -> None:
             sql = """
             SELECT prev_event_id, internal_metadata
             FROM event_edges
@@ -286,7 +281,9 @@ class PersistEventsStore:
         # and their prev events.
         existing_prevs = set()
 
-        def _get_prevs_before_rejected_txn(txn, batch):
+        def _get_prevs_before_rejected_txn(
+            txn: LoggingTransaction, batch: Collection[str]
+        ) -> None:
             to_recursively_check = batch
 
             while to_recursively_check:
@@ -516,7 +513,7 @@ class PersistEventsStore:
     @classmethod
     def _add_chain_cover_index(
         cls,
-        txn,
+        txn: LoggingTransaction,
         db_pool: DatabasePool,
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
@@ -810,7 +807,7 @@ class PersistEventsStore:
 
     @staticmethod
     def _allocate_chain_ids(
-        txn,
+        txn: LoggingTransaction,
         db_pool: DatabasePool,
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
@@ -944,7 +941,7 @@ class PersistEventsStore:
         self,
         txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-    ):
+    ) -> None:
         """Persist the mapping from transaction IDs to event IDs (if defined)."""
 
         to_insert = []
@@ -998,7 +995,7 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         state_delta_by_room: Dict[str, DeltaState],
         stream_id: int,
-    ):
+    ) -> None:
         for room_id, delta_state in state_delta_by_room.items():
             to_delete = delta_state.to_delete
             to_insert = delta_state.to_insert
@@ -1156,7 +1153,7 @@ class PersistEventsStore:
                 txn, room_id, members_changed
             )
 
-    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
+    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         """Update the room version in the database based off current state
         events.
 
@@ -1190,7 +1187,7 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         new_forward_extremities: Dict[str, Set[str]],
         max_stream_order: int,
-    ):
+    ) -> None:
         for room_id in new_forward_extremities.keys():
             self.db_pool.simple_delete_txn(
                 txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
@@ -1255,9 +1252,9 @@ class PersistEventsStore:
 
     def _update_room_depths_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-    ):
+    ) -> None:
         """Update min_depth for each room
 
         Args:
@@ -1386,7 +1383,7 @@ class PersistEventsStore:
             # nothing to do here
             return
 
-        def event_dict(event):
+        def event_dict(event: EventBase) -> JsonDict:
             d = event.get_dict()
             d.pop("redacted", None)
             d.pop("redacted_because", None)
@@ -1477,18 +1474,20 @@ class PersistEventsStore:
             ),
         )
 
-    def _store_rejected_events_txn(self, txn, events_and_contexts):
+    def _store_rejected_events_txn(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> List[Tuple[EventBase, EventContext]]:
         """Add rows to the 'rejections' table for received events which were
         rejected
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
+            txn: db connection
+            events_and_contexts: events we are persisting
 
         Returns:
-            list[(EventBase, EventContext)] new list, without the rejected
-                events.
+            new list, without the rejected events.
         """
         # Remove the rejected events from the list now that we've added them
         # to the events table and the events_json table.
@@ -1509,7 +1508,7 @@ class PersistEventsStore:
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         all_events_and_contexts: List[Tuple[EventBase, EventContext]],
         inhibit_local_membership_updates: bool = False,
-    ):
+    ) -> None:
         """Update all the miscellaneous tables for new events
 
         Args:
@@ -1600,15 +1599,14 @@ class PersistEventsStore:
             inhibit_local_membership_updates=inhibit_local_membership_updates,
         )
 
-        # Insert event_reference_hashes table.
-        self._store_event_reference_hashes_txn(
-            txn, [event for event, _ in events_and_contexts]
-        )
-
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
 
-    def _add_to_cache(self, txn, events_and_contexts):
+    def _add_to_cache(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> None:
         to_prefill = []
 
         rows = []
@@ -1639,7 +1637,7 @@ class PersistEventsStore:
             if not row["rejects"] and not row["redacts"]:
                 to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
-        def prefill():
+        def prefill() -> None:
             for cache_entry in to_prefill:
                 self.store._get_event_cache.set(
                     (cache_entry.event.event_id,), cache_entry
@@ -1669,19 +1667,24 @@ class PersistEventsStore:
         )
 
     def insert_labels_for_event_txn(
-        self, txn, event_id, labels, room_id, topological_ordering
-    ):
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        labels: List[str],
+        room_id: str,
+        topological_ordering: int,
+    ) -> None:
         """Store the mapping between an event's ID and its labels, with one row per
         (event_id, label) tuple.
 
         Args:
-            txn (LoggingTransaction): The transaction to execute.
-            event_id (str): The event's ID.
-            labels (list[str]): A list of text labels.
-            room_id (str): The ID of the room the event was sent to.
-            topological_ordering (int): The position of the event in the room's topology.
+            txn: The transaction to execute.
+            event_id: The event's ID.
+            labels: A list of text labels.
+            room_id: The ID of the room the event was sent to.
+            topological_ordering: The position of the event in the room's topology.
         """
-        return self.db_pool.simple_insert_many_txn(
+        self.db_pool.simple_insert_many_txn(
             txn=txn,
             table="event_labels",
             keys=("event_id", "label", "room_id", "topological_ordering"),
@@ -1690,44 +1693,32 @@ class PersistEventsStore:
             ],
         )
 
-    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+    def _insert_event_expiry_txn(
+        self, txn: LoggingTransaction, event_id: str, expiry_ts: int
+    ) -> None:
         """Save the expiry timestamp associated with a given event ID.
 
         Args:
-            txn (LoggingTransaction): The database transaction to use.
-            event_id (str): The event ID the expiry timestamp is associated with.
-            expiry_ts (int): The timestamp at which to expire (delete) the event.
+            txn: The database transaction to use.
+            event_id: The event ID the expiry timestamp is associated with.
+            expiry_ts: The timestamp at which to expire (delete) the event.
         """
-        return self.db_pool.simple_insert_txn(
+        self.db_pool.simple_insert_txn(
             txn=txn,
             table="event_expiry",
             values={"event_id": event_id, "expiry_ts": expiry_ts},
         )
 
-    def _store_event_reference_hashes_txn(self, txn, events):
-        """Store a hash for a PDU
-        Args:
-            txn (cursor):
-            events (list): list of Events.
-        """
-
-        vals = []
-        for event in events:
-            ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
-            vals.append((event.event_id, ref_alg, memoryview(ref_hash_bytes)))
-
-        self.db_pool.simple_insert_many_txn(
-            txn,
-            table="event_reference_hashes",
-            keys=("event_id", "algorithm", "hash"),
-            values=vals,
-        )
-
     def _store_room_members_txn(
-        self, txn, events, *, inhibit_local_membership_updates: bool = False
-    ):
+        self,
+        txn: LoggingTransaction,
+        events: List[EventBase],
+        *,
+        inhibit_local_membership_updates: bool = False,
+    ) -> None:
         """
         Store a room member in the database.
+
         Args:
             txn: The transaction to use.
             events: List of events to store.
@@ -1767,6 +1758,7 @@ class PersistEventsStore:
         )
 
         for event in events:
+            assert event.internal_metadata.stream_ordering is not None
             txn.call_after(
                 self.store._membership_stream_cache.entity_has_changed,
                 event.state_key,
@@ -1815,55 +1807,50 @@ class PersistEventsStore:
             txn: The current database transaction.
             event: The event which might have relations.
         """
-        relation = event.content.get("m.relates_to")
+        relation = relation_from_event(event)
         if not relation:
-            # No relations
-            return
-
-        # Relations must have a type and parent event ID.
-        rel_type = relation.get("rel_type")
-        if not isinstance(rel_type, str):
+            # No relation, nothing to do.
             return
 
-        parent_id = relation.get("event_id")
-        if not isinstance(parent_id, str):
-            return
-
-        # Annotations have a key field.
-        aggregation_key = None
-        if rel_type == RelationTypes.ANNOTATION:
-            aggregation_key = relation.get("key")
-
         self.db_pool.simple_insert_txn(
             txn,
             table="event_relations",
             values={
                 "event_id": event.event_id,
-                "relates_to_id": parent_id,
-                "relation_type": rel_type,
-                "aggregation_key": aggregation_key,
+                "relates_to_id": relation.parent_id,
+                "relation_type": relation.rel_type,
+                "aggregation_key": relation.aggregation_key,
             },
         )
 
-        txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,))
         txn.call_after(
-            self.store.get_aggregation_groups_for_event.invalidate, (parent_id,)
+            self.store.get_relations_for_event.invalidate, (relation.parent_id,)
+        )
+        txn.call_after(
+            self.store.get_aggregation_groups_for_event.invalidate,
+            (relation.parent_id,),
         )
 
-        if rel_type == RelationTypes.REPLACE:
-            txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+        if relation.rel_type == RelationTypes.REPLACE:
+            txn.call_after(
+                self.store.get_applicable_edit.invalidate, (relation.parent_id,)
+            )
 
-        if rel_type == RelationTypes.THREAD:
-            txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+        if relation.rel_type == RelationTypes.THREAD:
+            txn.call_after(
+                self.store.get_thread_summary.invalidate, (relation.parent_id,)
+            )
             # It should be safe to only invalidate the cache if the user has not
             # previously participated in the thread, but that's difficult (and
             # potentially error-prone) so it is always invalidated.
             txn.call_after(
                 self.store.get_thread_participated.invalidate,
-                (parent_id, event.sender),
+                (relation.parent_id, event.sender),
             )
 
-    def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
+    def _handle_insertion_event(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         """Handles keeping track of insertion events and edges/connections.
         Part of MSC2716.
 
@@ -1924,7 +1911,7 @@ class PersistEventsStore:
                 },
             )
 
-    def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):
+    def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None:
         """Handles inserting the batch edges/connections between the batch event
         and an insertion event. Part of MSC2716.
 
@@ -2024,25 +2011,29 @@ class PersistEventsStore:
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )
 
-    def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
         if isinstance(event.content.get("topic"), str):
             self.store_event_search_txn(
                 txn, event, "content.topic", event.content["topic"]
             )
 
-    def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
         if isinstance(event.content.get("name"), str):
             self.store_event_search_txn(
                 txn, event, "content.name", event.content["name"]
             )
 
-    def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_message_txn(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         if isinstance(event.content.get("body"), str):
             self.store_event_search_txn(
                 txn, event, "content.body", event.content["body"]
             )
 
-    def _store_retention_policy_for_room_txn(self, txn, event):
+    def _store_retention_policy_for_room_txn(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         if not event.is_state():
             logger.debug("Ignoring non-state m.room.retention event")
             return
@@ -2102,8 +2093,11 @@ class PersistEventsStore:
         )
 
     def _set_push_actions_for_event_and_users_txn(
-        self, txn, events_and_contexts, all_events_and_contexts
-    ):
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        all_events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> None:
         """Handles moving push actions from staging table to main
         event_push_actions table for all events in `events_and_contexts`.
 
@@ -2111,12 +2105,10 @@ class PersistEventsStore:
         from the push action staging area.
 
         Args:
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
-            all_events_and_contexts (list[(EventBase, EventContext)]): all
-                events that we were going to persist. This includes events
-                we've already persisted, etc, that wouldn't appear in
-                events_and_context.
+            events_and_contexts: events we are persisting
+            all_events_and_contexts: all events that we were going to persist.
+                This includes events we've already persisted, etc, that wouldn't
+                appear in events_and_context.
         """
 
         # Only non outlier events will have push actions associated with them,
@@ -2185,7 +2177,9 @@ class PersistEventsStore:
             ),
         )
 
-    def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+    def _remove_push_actions_for_event_id_txn(
+        self, txn: LoggingTransaction, room_id: str, event_id: str
+    ) -> None:
         # Sad that we have to blow away the cache for the whole room here
         txn.call_after(
             self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
@@ -2196,7 +2190,9 @@ class PersistEventsStore:
             (room_id, event_id),
         )
 
-    def _store_rejections_txn(self, txn, event_id, reason):
+    def _store_rejections_txn(
+        self, txn: LoggingTransaction, event_id: str, reason: str
+    ) -> None:
         self.db_pool.simple_insert_txn(
             txn,
             table="rejections",
@@ -2208,8 +2204,10 @@ class PersistEventsStore:
         )
 
     def _store_event_state_mappings_txn(
-        self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
-    ):
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: Collection[Tuple[EventBase, EventContext]],
+    ) -> None:
         state_groups = {}
         for event, context in events_and_contexts:
             if event.internal_metadata.is_outlier():
@@ -2266,7 +2264,9 @@ class PersistEventsStore:
                 state_group_id,
             )
 
-    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+    def _update_min_depth_for_room_txn(
+        self, txn: LoggingTransaction, room_id: str, depth: int
+    ) -> None:
         min_depth = self.store._get_min_depth_interaction(txn, room_id)
 
         if min_depth is not None and depth >= min_depth:
@@ -2279,7 +2279,9 @@ class PersistEventsStore:
             values={"min_depth": depth},
         )
 
-    def _handle_mult_prev_events(self, txn, events):
+    def _handle_mult_prev_events(
+        self, txn: LoggingTransaction, events: List[EventBase]
+    ) -> None:
         """
         For the given event, update the event edges table and forward and
         backward extremities tables.
@@ -2297,7 +2299,9 @@ class PersistEventsStore:
 
         self._update_backward_extremeties(txn, events)
 
-    def _update_backward_extremeties(self, txn, events):
+    def _update_backward_extremeties(
+        self, txn: LoggingTransaction, events: List[EventBase]
+    ) -> None:
         """Updates the event_backward_extremities tables based on the new/updated
         events being persisted.
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index a4a604a499..5b22d6b452 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -14,6 +14,7 @@
 
 import logging
 import threading
+import weakref
 from enum import Enum, auto
 from typing import (
     TYPE_CHECKING,
@@ -23,6 +24,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    MutableMapping,
     Optional,
     Set,
     Tuple,
@@ -248,6 +250,12 @@ class EventsWorkerStore(SQLBaseStore):
             str, ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = {}
 
+        # We keep track of the events we have currently loaded in memory so that
+        # we can reuse them even if they've been evicted from the cache. We only
+        # track events that don't need redacting in here (as then we don't need
+        # to track redaction status).
+        self._event_ref: MutableMapping[str, EventBase] = weakref.WeakValueDictionary()
+
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list: List[
             Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
@@ -723,6 +731,8 @@ class EventsWorkerStore(SQLBaseStore):
 
     def _invalidate_get_event_cache(self, event_id: str) -> None:
         self._get_event_cache.invalidate((event_id,))
+        self._event_ref.pop(event_id, None)
+        self._current_event_fetches.pop(event_id, None)
 
     def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
@@ -738,13 +748,30 @@ class EventsWorkerStore(SQLBaseStore):
         event_map = {}
 
         for event_id in events:
+            # First check if it's in the event cache
             ret = self._get_event_cache.get(
                 (event_id,), None, update_metrics=update_metrics
             )
-            if not ret:
+            if ret:
+                event_map[event_id] = ret
                 continue
 
-            event_map[event_id] = ret
+            # Otherwise check if we still have the event in memory.
+            event = self._event_ref.get(event_id)
+            if event:
+                # Reconstruct an event cache entry
+
+                cache_entry = EventCacheEntry(
+                    event=event,
+                    # We don't cache weakrefs to redacted events, so we know
+                    # this is None.
+                    redacted_event=None,
+                )
+                event_map[event_id] = cache_entry
+
+                # We add the entry back into the cache as we want to keep
+                # recently queried events in the cache.
+                self._get_event_cache.set((event_id,), cache_entry)
 
         return event_map
 
@@ -1124,6 +1151,10 @@ class EventsWorkerStore(SQLBaseStore):
             self._get_event_cache.set((event_id,), cache_entry)
             result_map[event_id] = cache_entry
 
+            if not redacted_event:
+                # We only cache references to unredacted events.
+                self._event_ref[event_id] = original_ev
+
         return result_map
 
     async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 1480a0f048..d03555a585 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -23,6 +23,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
+from synapse.storage.types import Cursor
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -71,7 +72,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         self._last_user_visit_update = self._get_start_of_day()
 
     @wrap_as_background_process("read_forward_extremities")
-    async def _read_forward_extremities(self):
+    async def _read_forward_extremities(self) -> None:
         def fetch(txn):
             txn.execute(
                 """
@@ -95,7 +96,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (x[0] - 1) * x[1] for x in res if x[1]
         )
 
-    async def count_daily_e2ee_messages(self):
+    async def count_daily_e2ee_messages(self) -> int:
         """
         Returns an estimate of the number of messages sent in the last day.
 
@@ -115,7 +116,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
         return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
 
-    async def count_daily_sent_e2ee_messages(self):
+    async def count_daily_sent_e2ee_messages(self) -> int:
         def _count_messages(txn):
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.
@@ -136,7 +137,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_daily_sent_e2ee_messages", _count_messages
         )
 
-    async def count_daily_active_e2ee_rooms(self):
+    async def count_daily_active_e2ee_rooms(self) -> int:
         def _count(txn):
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events
@@ -151,7 +152,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_daily_active_e2ee_rooms", _count
         )
 
-    async def count_daily_messages(self):
+    async def count_daily_messages(self) -> int:
         """
         Returns an estimate of the number of messages sent in the last day.
 
@@ -171,7 +172,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
         return await self.db_pool.runInteraction("count_messages", _count_messages)
 
-    async def count_daily_sent_messages(self):
+    async def count_daily_sent_messages(self) -> int:
         def _count_messages(txn):
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.
@@ -192,7 +193,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_daily_sent_messages", _count_messages
         )
 
-    async def count_daily_active_rooms(self):
+    async def count_daily_active_rooms(self) -> int:
         def _count(txn):
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events
@@ -226,7 +227,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_monthly_users", self._count_users, thirty_days_ago
         )
 
-    def _count_users(self, txn, time_from):
+    def _count_users(self, txn: Cursor, time_from: int) -> int:
         """
         Returns number of users seen in the past time_from period
         """
@@ -238,7 +239,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             ) u
         """
         txn.execute(sql, (time_from,))
-        (count,) = txn.fetchone()
+        # Mypy knows that fetchone() might return None if there are no rows.
+        # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
+        # returns exactly one row.
+        (count,) = txn.fetchone()  # type: ignore[misc]
         return count
 
     async def count_r30_users(self) -> Dict[str, int]:
@@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_r30v2_users", _count_r30v2_users
         )
 
-    def _get_start_of_day(self):
+    def _get_start_of_day(self) -> int:
         """
         Returns millisecond unixtime for start of UTC day.
         """
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index bfc85b3add..38ba91af4c 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -69,7 +69,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         #     event_forward_extremities
         #     event_json
         #     event_push_actions
-        #     event_reference_hashes
         #     event_relations
         #     event_search
         #     event_to_state_groups
@@ -220,7 +219,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "event_auth",
             "event_edges",
             "event_forward_extremities",
-            "event_reference_hashes",
             "event_relations",
             "event_search",
             "rejections",
@@ -369,7 +367,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "event_edges",
             "event_json",
             "event_push_actions_staging",
-            "event_reference_hashes",
             "event_relations",
             "event_to_state_groups",
             "event_auth_chains",
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 484976ca6b..fe8fded88b 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -34,7 +34,7 @@ from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 logger = logging.getLogger(__name__)
@@ -161,7 +161,9 @@ class RelationsWorkerStore(SQLBaseStore):
             if len(events) > limit and last_topo_id and last_stream_id:
                 next_key = RoomStreamToken(last_topo_id, last_stream_id)
                 if from_token:
-                    next_token = from_token.copy_and_replace("room_key", next_key)
+                    next_token = from_token.copy_and_replace(
+                        StreamKeyType.ROOM, next_key
+                    )
                 else:
                     next_token = StreamToken(
                         room_key=next_key,
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 3c49e7ec98..78e0773b2a 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -14,7 +14,7 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
 
 import attr
 
@@ -27,7 +27,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -149,7 +149,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
             self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
         )
 
-    async def _background_reindex_search(self, progress, batch_size):
+    async def _background_reindex_search(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         # we work through the events table from highest stream id to lowest
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
@@ -157,7 +159,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
 
-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT stream_ordering, event_id, room_id, type, json, "
                 " origin_server_ts FROM events"
@@ -255,12 +257,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         return result
 
-    async def _background_reindex_gin_search(self, progress, batch_size):
+    async def _background_reindex_gin_search(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """This handles old synapses which used GIST indexes, if any;
         converting them back to be GIN as per the actual schema.
         """
 
-        def create_index(conn):
+        def create_index(conn: LoggingDatabaseConnection) -> None:
             conn.rollback()
 
             # we have to set autocommit, because postgres refuses to
@@ -299,7 +303,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
         )
         return 1
 
-    async def _background_reindex_search_order(self, progress, batch_size):
+    async def _background_reindex_search_order(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
@@ -307,7 +313,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         if not have_added_index:
 
-            def create_index(conn):
+            def create_index(conn: LoggingDatabaseConnection) -> None:
                 conn.rollback()
                 conn.set_session(autocommit=True)
                 c = conn.cursor()
@@ -336,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
                 pg,
             )
 
-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
             sql = (
                 "UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
                 " origin_server_ts = e.origin_server_ts"
@@ -644,7 +650,8 @@ class SearchStore(SearchBackgroundUpdateStore):
         else:
             raise Exception("Unrecognized database engine")
 
-        args.append(limit)
+        # mypy expects to append only a `str`, not an `int`
+        args.append(limit)  # type: ignore[arg-type]
 
         results = await self.db_pool.execute(
             "search_rooms", self.db_pool.cursor_to_dict, sql, *args
@@ -705,7 +712,7 @@ class SearchStore(SearchBackgroundUpdateStore):
             A set of strings.
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> Set[str]:
             highlight_words = set()
             for event in events:
                 # As a hack we simply join values of all possible keys. This is
@@ -759,11 +766,11 @@ class SearchStore(SearchBackgroundUpdateStore):
         return await self.db_pool.runInteraction("_find_highlights", f)
 
 
-def _to_postgres_options(options_dict):
+def _to_postgres_options(options_dict: JsonDict) -> str:
     return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
 
 
-def _parse_query(database_engine, search_term):
+def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
     """Takes a plain unicode string from the user and converts it into a form
     that can be passed to database.
     We use this so that we can add prefix matching, which isn't something
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 0373af86c8..0e3a23a140 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -788,30 +788,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return None
 
     async def get_current_room_stream_token_for_room_id(
-        self, room_id: Optional[str] = None
+        self, room_id: str
     ) -> RoomStreamToken:
-        """Returns the current position of the rooms stream.
-
-        By default, it returns a live token with the current global stream
-        token. Specifying a `room_id` causes it to return a historic token with
-        the room specific topological token.
-        """
+        """Returns the current position of the rooms stream (historic token)."""
         stream_ordering = self.get_room_max_stream_ordering()
-        if room_id is None:
-            return RoomStreamToken(None, stream_ordering)
-        else:
-            topo = await self.db_pool.runInteraction(
-                "_get_max_topological_txn", self._get_max_topological_txn, room_id
-            )
-            return RoomStreamToken(topo, stream_ordering)
+        topo = await self.db_pool.runInteraction(
+            "_get_max_topological_txn", self._get_max_topological_txn, room_id
+        )
+        return RoomStreamToken(topo, stream_ordering)
 
     def get_stream_id_for_event_txn(
         self,
         txn: LoggingTransaction,
         event_id: str,
-        allow_none=False,
-    ) -> int:
-        return self.db_pool.simple_select_one_onecol_txn(
+        allow_none: bool = False,
+    ) -> Optional[int]:
+        # Type ignore: we pass keyvalues a Dict[str, str]; the function wants
+        # Dict[str, Any]. I think mypy is unhappy because Dict is invariant?
+        return self.db_pool.simple_select_one_onecol_txn(  # type: ignore[call-overload]
             txn=txn,
             table="events",
             keyvalues={"event_id": event_id},
@@ -873,7 +867,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
         rows = txn.fetchall()
-        return rows[0][0] if rows else 0
+        # An aggregate function like MAX() will always return one row per group
+        # so we can safely rely on the lookup here. For example, when a we
+        # lookup a `room_id` which does not exist, `rows` will look like
+        # `[(None,)]`
+        return rows[0][0] if rows[0][0] is not None else 0
 
     @staticmethod
     def _set_before_and_after(
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index afb7d5054d..f51b3d228e 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -11,25 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Mapping
 
 from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
 from .postgres import PostgresEngine
 from .sqlite import Sqlite3Engine
 
 
-def create_engine(database_config) -> BaseDatabaseEngine:
+def create_engine(database_config: Mapping[str, Any]) -> BaseDatabaseEngine:
     name = database_config["name"]
 
     if name == "sqlite3":
-        import sqlite3
-
-        return Sqlite3Engine(sqlite3, database_config)
+        return Sqlite3Engine(database_config)
 
     if name == "psycopg2":
-        # Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
-        import psycopg2
-
-        return PostgresEngine(psycopg2, database_config)
+        return PostgresEngine(database_config)
 
     raise RuntimeError("Unsupported database engine '%s'" % (name,))
 
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 143cd98ca2..971ff82693 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -13,9 +13,12 @@
 # limitations under the License.
 import abc
 from enum import IntEnum
-from typing import Generic, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, TypeVar
 
-from synapse.storage.types import Connection
+from synapse.storage.types import Connection, Cursor, DBAPI2Module
+
+if TYPE_CHECKING:
+    from synapse.storage.database import LoggingDatabaseConnection
 
 
 class IsolationLevel(IntEnum):
@@ -32,7 +35,7 @@ ConnectionType = TypeVar("ConnectionType", bound=Connection)
 
 
 class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
-    def __init__(self, module, database_config: dict):
+    def __init__(self, module: DBAPI2Module, config: Mapping[str, Any]):
         self.module = module
 
     @property
@@ -69,7 +72,7 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
-    def check_new_database(self, txn) -> None:
+    def check_new_database(self, txn: Cursor) -> None:
         """Gets called when setting up a brand new database. This allows us to
         apply stricter checks on new databases versus existing database.
         """
@@ -79,8 +82,11 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
     def convert_param_style(self, sql: str) -> str:
         ...
 
+    # This method would ideally take a plain ConnectionType, but it seems that
+    # the Sqlite engine expects to use LoggingDatabaseConnection.cursor
+    # instead of sqlite3.Connection.cursor: only the former takes a txn_name.
     @abc.abstractmethod
-    def on_new_connection(self, db_conn: ConnectionType) -> None:
+    def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
         ...
 
     @abc.abstractmethod
@@ -92,7 +98,7 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
-    def lock_table(self, txn, table: str) -> None:
+    def lock_table(self, txn: Cursor, table: str) -> None:
         ...
 
     @property
@@ -102,12 +108,12 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
-    def in_transaction(self, conn: Connection) -> bool:
+    def in_transaction(self, conn: ConnectionType) -> bool:
         """Whether the connection is currently in a transaction."""
         ...
 
     @abc.abstractmethod
-    def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
+    def attempt_to_set_autocommit(self, conn: ConnectionType, autocommit: bool) -> None:
         """Attempt to set the connections autocommit mode.
 
         When True queries are run outside of transactions.
@@ -119,8 +125,8 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
 
     @abc.abstractmethod
     def attempt_to_set_isolation_level(
-        self, conn: Connection, isolation_level: Optional[int]
-    ):
+        self, conn: ConnectionType, isolation_level: Optional[int]
+    ) -> None:
         """Attempt to set the connections isolation level.
 
         Note: This has no effect on SQLite3, as transactions are SERIALIZABLE by default.
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index e8d29e2870..391f8ed24a 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -13,39 +13,47 @@
 # limitations under the License.
 
 import logging
-from typing import Mapping, Optional
+from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast
 
 from synapse.storage.engines._base import (
     BaseDatabaseEngine,
     IncorrectDatabaseSetup,
     IsolationLevel,
 )
-from synapse.storage.types import Connection
+from synapse.storage.types import Cursor
+
+if TYPE_CHECKING:
+    import psycopg2  # noqa: F401
+
+    from synapse.storage.database import LoggingDatabaseConnection
+
 
 logger = logging.getLogger(__name__)
 
 
-class PostgresEngine(BaseDatabaseEngine):
-    def __init__(self, database_module, database_config):
-        super().__init__(database_module, database_config)
-        self.module.extensions.register_type(self.module.extensions.UNICODE)
+class PostgresEngine(BaseDatabaseEngine["psycopg2.connection"]):
+    def __init__(self, database_config: Mapping[str, Any]):
+        import psycopg2.extensions
+
+        super().__init__(psycopg2, database_config)
+        psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
 
         # Disables passing `bytes` to txn.execute, c.f. #6186. If you do
         # actually want to use bytes than wrap it in `bytearray`.
-        def _disable_bytes_adapter(_):
+        def _disable_bytes_adapter(_: bytes) -> NoReturn:
             raise Exception("Passing bytes to DB is disabled.")
 
-        self.module.extensions.register_adapter(bytes, _disable_bytes_adapter)
-        self.synchronous_commit = database_config.get("synchronous_commit", True)
-        self._version = None  # unknown as yet
+        psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter)
+        self.synchronous_commit: bool = database_config.get("synchronous_commit", True)
+        self._version: Optional[int] = None  # unknown as yet
 
         self.isolation_level_map: Mapping[int, int] = {
-            IsolationLevel.READ_COMMITTED: self.module.extensions.ISOLATION_LEVEL_READ_COMMITTED,
-            IsolationLevel.REPEATABLE_READ: self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
-            IsolationLevel.SERIALIZABLE: self.module.extensions.ISOLATION_LEVEL_SERIALIZABLE,
+            IsolationLevel.READ_COMMITTED: psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED,
+            IsolationLevel.REPEATABLE_READ: psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
+            IsolationLevel.SERIALIZABLE: psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE,
         }
         self.default_isolation_level = (
-            self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
+            psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ
         )
         self.config = database_config
 
@@ -53,19 +61,21 @@ class PostgresEngine(BaseDatabaseEngine):
     def single_threaded(self) -> bool:
         return False
 
-    def get_db_locale(self, txn):
+    def get_db_locale(self, txn: Cursor) -> Tuple[str, str]:
         txn.execute(
             "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
         )
-        collation, ctype = txn.fetchone()
+        collation, ctype = cast(Tuple[str, str], txn.fetchone())
         return collation, ctype
 
-    def check_database(self, db_conn, allow_outdated_version: bool = False):
+    def check_database(
+        self, db_conn: "psycopg2.connection", allow_outdated_version: bool = False
+    ) -> None:
         # Get the version of PostgreSQL that we're using. As per the psycopg2
         # docs: The number is formed by converting the major, minor, and
         # revision numbers into two-decimal-digit numbers and appending them
         # together. For example, version 8.1.5 will be returned as 80105
-        self._version = db_conn.server_version
+        self._version = cast(int, db_conn.server_version)
         allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
 
         # Are we on a supported PostgreSQL version?
@@ -108,7 +118,7 @@ class PostgresEngine(BaseDatabaseEngine):
                         ctype,
                     )
 
-    def check_new_database(self, txn):
+    def check_new_database(self, txn: Cursor) -> None:
         """Gets called when setting up a brand new database. This allows us to
         apply stricter checks on new databases versus existing database.
         """
@@ -129,10 +139,10 @@ class PostgresEngine(BaseDatabaseEngine):
                 "See docs/postgres.md for more information." % ("\n".join(errors))
             )
 
-    def convert_param_style(self, sql):
+    def convert_param_style(self, sql: str) -> str:
         return sql.replace("?", "%s")
 
-    def on_new_connection(self, db_conn):
+    def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
         db_conn.set_isolation_level(self.default_isolation_level)
 
         # Set the bytea output to escape, vs the default of hex
@@ -149,14 +159,14 @@ class PostgresEngine(BaseDatabaseEngine):
         db_conn.commit()
 
     @property
-    def can_native_upsert(self):
+    def can_native_upsert(self) -> bool:
         """
         Can we use native UPSERTs?
         """
         return True
 
     @property
-    def supports_using_any_list(self):
+    def supports_using_any_list(self) -> bool:
         """Do we support using `a = ANY(?)` and passing a list"""
         return True
 
@@ -165,27 +175,25 @@ class PostgresEngine(BaseDatabaseEngine):
         """Do we support the `RETURNING` clause in insert/update/delete?"""
         return True
 
-    def is_deadlock(self, error):
-        if isinstance(error, self.module.DatabaseError):
+    def is_deadlock(self, error: Exception) -> bool:
+        import psycopg2.extensions
+
+        if isinstance(error, psycopg2.DatabaseError):
             # https://www.postgresql.org/docs/current/static/errcodes-appendix.html
             # "40001" serialization_failure
             # "40P01" deadlock_detected
             return error.pgcode in ["40001", "40P01"]
         return False
 
-    def is_connection_closed(self, conn):
+    def is_connection_closed(self, conn: "psycopg2.connection") -> bool:
         return bool(conn.closed)
 
-    def lock_table(self, txn, table):
+    def lock_table(self, txn: Cursor, table: str) -> None:
         txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
 
     @property
-    def server_version(self):
-        """Returns a string giving the server version. For example: '8.1.5'
-
-        Returns:
-            string
-        """
+    def server_version(self) -> str:
+        """Returns a string giving the server version. For example: '8.1.5'."""
         # note that this is a bit of a hack because it relies on check_database
         # having been called. Still, that should be a safe bet here.
         numver = self._version
@@ -197,17 +205,21 @@ class PostgresEngine(BaseDatabaseEngine):
         else:
             return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
 
-    def in_transaction(self, conn: Connection) -> bool:
-        return conn.status != self.module.extensions.STATUS_READY  # type: ignore
+    def in_transaction(self, conn: "psycopg2.connection") -> bool:
+        import psycopg2.extensions
+
+        return conn.status != psycopg2.extensions.STATUS_READY
 
-    def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
-        return conn.set_session(autocommit=autocommit)  # type: ignore
+    def attempt_to_set_autocommit(
+        self, conn: "psycopg2.connection", autocommit: bool
+    ) -> None:
+        return conn.set_session(autocommit=autocommit)
 
     def attempt_to_set_isolation_level(
-        self, conn: Connection, isolation_level: Optional[int]
-    ):
+        self, conn: "psycopg2.connection", isolation_level: Optional[int]
+    ) -> None:
         if isolation_level is None:
             isolation_level = self.default_isolation_level
         else:
             isolation_level = self.isolation_level_map[isolation_level]
-        return conn.set_isolation_level(isolation_level)  # type: ignore
+        return conn.set_isolation_level(isolation_level)
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 6c19e55999..621f2c5efe 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,21 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import platform
+import sqlite3
 import struct
 import threading
-import typing
-from typing import Optional
+from typing import TYPE_CHECKING, Any, List, Mapping, Optional
 
 from synapse.storage.engines import BaseDatabaseEngine
-from synapse.storage.types import Connection
+from synapse.storage.types import Cursor
 
-if typing.TYPE_CHECKING:
-    import sqlite3  # noqa: F401
+if TYPE_CHECKING:
+    from synapse.storage.database import LoggingDatabaseConnection
 
 
-class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
-    def __init__(self, database_module, database_config):
-        super().__init__(database_module, database_config)
+class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
+    def __init__(self, database_config: Mapping[str, Any]):
+        super().__init__(sqlite3, database_config)
 
         database = database_config.get("args", {}).get("database")
         self._is_in_memory = database in (
@@ -37,7 +37,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
         if platform.python_implementation() == "PyPy":
             # pypy's sqlite3 module doesn't handle bytearrays, convert them
             # back to bytes.
-            database_module.register_adapter(bytearray, lambda array: bytes(array))
+            sqlite3.register_adapter(bytearray, lambda array: bytes(array))
 
         # The current max state_group, or None if we haven't looked
         # in the DB yet.
@@ -49,41 +49,43 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
         return True
 
     @property
-    def can_native_upsert(self):
+    def can_native_upsert(self) -> bool:
         """
         Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
         more work we haven't done yet to tell what was inserted vs updated.
         """
-        return self.module.sqlite_version_info >= (3, 24, 0)
+        return sqlite3.sqlite_version_info >= (3, 24, 0)
 
     @property
-    def supports_using_any_list(self):
+    def supports_using_any_list(self) -> bool:
         """Do we support using `a = ANY(?)` and passing a list"""
         return False
 
     @property
     def supports_returning(self) -> bool:
         """Do we support the `RETURNING` clause in insert/update/delete?"""
-        return self.module.sqlite_version_info >= (3, 35, 0)
+        return sqlite3.sqlite_version_info >= (3, 35, 0)
 
-    def check_database(self, db_conn, allow_outdated_version: bool = False):
+    def check_database(
+        self, db_conn: sqlite3.Connection, allow_outdated_version: bool = False
+    ) -> None:
         if not allow_outdated_version:
-            version = self.module.sqlite_version_info
+            version = sqlite3.sqlite_version_info
             # Synapse is untested against older SQLite versions, and we don't want
             # to let users upgrade to a version of Synapse with broken support for their
             # sqlite version, because it risks leaving them with a half-upgraded db.
             if version < (3, 22, 0):
                 raise RuntimeError("Synapse requires sqlite 3.22 or above.")
 
-    def check_new_database(self, txn):
+    def check_new_database(self, txn: Cursor) -> None:
         """Gets called when setting up a brand new database. This allows us to
         apply stricter checks on new databases versus existing database.
         """
 
-    def convert_param_style(self, sql):
+    def convert_param_style(self, sql: str) -> str:
         return sql
 
-    def on_new_connection(self, db_conn):
+    def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
         # We need to import here to avoid an import loop.
         from synapse.storage.prepare_database import prepare_database
 
@@ -97,48 +99,46 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
         db_conn.execute("PRAGMA foreign_keys = ON;")
         db_conn.commit()
 
-    def is_deadlock(self, error):
+    def is_deadlock(self, error: Exception) -> bool:
         return False
 
-    def is_connection_closed(self, conn):
+    def is_connection_closed(self, conn: sqlite3.Connection) -> bool:
         return False
 
-    def lock_table(self, txn, table):
+    def lock_table(self, txn: Cursor, table: str) -> None:
         return
 
     @property
-    def server_version(self):
-        """Gets a string giving the server version. For example: '3.22.0'
+    def server_version(self) -> str:
+        """Gets a string giving the server version. For example: '3.22.0'."""
+        return "%i.%i.%i" % sqlite3.sqlite_version_info
 
-        Returns:
-            string
-        """
-        return "%i.%i.%i" % self.module.sqlite_version_info
-
-    def in_transaction(self, conn: Connection) -> bool:
-        return conn.in_transaction  # type: ignore
+    def in_transaction(self, conn: sqlite3.Connection) -> bool:
+        return conn.in_transaction
 
-    def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
+    def attempt_to_set_autocommit(
+        self, conn: sqlite3.Connection, autocommit: bool
+    ) -> None:
         # Twisted doesn't let us set attributes on the connections, so we can't
         # set the connection to autocommit mode.
         pass
 
     def attempt_to_set_isolation_level(
-        self, conn: Connection, isolation_level: Optional[int]
-    ):
-        # All transactions are SERIALIZABLE by default in sqllite
+        self, conn: sqlite3.Connection, isolation_level: Optional[int]
+    ) -> None:
+        # All transactions are SERIALIZABLE by default in sqlite
         pass
 
 
 # Following functions taken from: https://github.com/coleifer/peewee
 
 
-def _parse_match_info(buf):
+def _parse_match_info(buf: bytes) -> List[int]:
     bufsize = len(buf)
     return [struct.unpack("@I", buf[i : i + 4])[0] for i in range(0, bufsize, 4)]
 
 
-def _rank(raw_match_info):
+def _rank(raw_match_info: bytes) -> float:
     """Handle match_info called w/default args 'pcx' - based on the example rank
     function http://sqlite.org/fts3.html#appendix_a
     """
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 97118045a1..0fc282866b 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -25,6 +25,7 @@ from typing import (
     Collection,
     Deque,
     Dict,
+    Generator,
     Generic,
     Iterable,
     List,
@@ -207,7 +208,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
 
         return res
 
-    def _handle_queue(self, room_id):
+    def _handle_queue(self, room_id: str) -> None:
         """Attempts to handle the queue for a room if not already being handled.
 
         The queue's callback will be invoked with for each item in the queue,
@@ -227,7 +228,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
 
         self._currently_persisting_rooms.add(room_id)
 
-        async def handle_queue_loop():
+        async def handle_queue_loop() -> None:
             try:
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
@@ -250,15 +251,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
                         with PreserveLoggingContext():
                             item.deferred.callback(ret)
             finally:
-                queue = self._event_persist_queues.pop(room_id, None)
-                if queue:
-                    self._event_persist_queues[room_id] = queue
+                remaining_queue = self._event_persist_queues.pop(room_id, None)
+                if remaining_queue:
+                    self._event_persist_queues[room_id] = remaining_queue
                 self._currently_persisting_rooms.discard(room_id)
 
         # set handle_queue_loop off in the background
         run_as_background_process("persist_events", handle_queue_loop)
 
-    def _get_drainining_queue(self, room_id):
+    def _get_drainining_queue(
+        self, room_id: str
+    ) -> Generator[_EventPersistQueueItem, None, None]:
         queue = self._event_persist_queues.setdefault(room_id, deque())
 
         try:
@@ -317,7 +320,9 @@ class EventsPersistenceStorage:
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
 
-        async def enqueue(item):
+        async def enqueue(
+            item: Tuple[str, List[Tuple[EventBase, EventContext]]]
+        ) -> Dict[str, str]:
             room_id, evs_ctxs = item
             return await self._event_persist_queue.add_to_queue(
                 room_id, evs_ctxs, backfilled=backfilled
@@ -487,12 +492,6 @@ class EventsPersistenceStorage:
             # extremities in each room
             new_forward_extremities: Dict[str, Set[str]] = {}
 
-            # map room_id->(type,state_key)->event_id tracking the full
-            # state in each room after adding these events.
-            # This is simply used to prefill the get_current_state_ids
-            # cache
-            current_state_for_room: Dict[str, StateMap[str]] = {}
-
             # map room_id->(to_delete, to_insert) where to_delete is a list
             # of type/state keys to remove from current state, and to_insert
             # is a map (type,key)->event_id giving the state delta in each
@@ -628,14 +627,8 @@ class EventsPersistenceStorage:
 
                             state_delta_for_room[room_id] = delta
 
-                        # If we have the current_state then lets prefill
-                        # the cache with it.
-                        if current_state is not None:
-                            current_state_for_room[room_id] = current_state
-
             await self.persist_events_store._persist_events_and_state_updates(
                 chunk,
-                current_state_for_room=current_state_for_room,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremities=new_forward_extremities,
                 use_negative_stream_ordering=backfilled,
@@ -733,7 +726,8 @@ class EventsPersistenceStorage:
 
             The first state map is the full new current state and the second
             is the delta to the existing current state. If both are None then
-            there has been no change.
+            there has been no change. Either or neither can be None if there
+            has been a change.
 
             The function may prune some old entries from the set of new
             forward extremities if it's safe to do so.
@@ -743,9 +737,6 @@ class EventsPersistenceStorage:
             the new current state is only returned if we've already calculated
             it.
         """
-        # map from state_group to ((type, key) -> event_id) state map
-        state_groups_map = {}
-
         # Map from (prev state group, new state group) -> delta state dict
         state_group_deltas = {}
 
@@ -759,16 +750,6 @@ class EventsPersistenceStorage:
                     )
                 continue
 
-            if ctx.state_group in state_groups_map:
-                continue
-
-            # We're only interested in pulling out state that has already
-            # been cached in the context. We'll pull stuff out of the DB later
-            # if necessary.
-            current_state_ids = ctx.get_cached_current_state_ids()
-            if current_state_ids is not None:
-                state_groups_map[ctx.state_group] = current_state_ids
-
             if ctx.prev_group:
                 state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
 
@@ -826,18 +807,14 @@ class EventsPersistenceStorage:
             delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
             if delta_ids is not None:
                 # We have a delta from the existing to new current state,
-                # so lets just return that. If we happen to already have
-                # the current state in memory then lets also return that,
-                # but it doesn't matter if we don't.
-                new_state = state_groups_map.get(new_state_group)
-                return new_state, delta_ids, new_latest_event_ids
+                # so lets just return that.
+                return None, delta_ids, new_latest_event_ids
 
         # Now that we have calculated new_state_groups we need to get
         # their state IDs so we can resolve to a single state set.
-        missing_state = new_state_groups - set(state_groups_map)
-        if missing_state:
-            group_to_state = await self.state_store._get_state_for_groups(missing_state)
-            state_groups_map.update(group_to_state)
+        state_groups_map = await self.state_store._get_state_for_groups(
+            new_state_groups
+        )
 
         if len(new_state_groups) == 1:
             # If there is only one state group, then we know what the current
@@ -1130,7 +1107,7 @@ class EventsPersistenceStorage:
 
         return False
 
-    async def _handle_potentially_left_users(self, user_ids: Set[str]):
+    async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None:
         """Given a set of remote users check if the server still shares a room with
         them. If not then mark those users' device cache as stale.
         """
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 546d6bae6e..c33df42084 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -85,7 +85,7 @@ def prepare_database(
     database_engine: BaseDatabaseEngine,
     config: Optional[HomeServerConfig],
     databases: Collection[str] = ("main", "state"),
-):
+) -> None:
     """Prepares a physical database for usage. Will either create all necessary tables
     or upgrade from an older schema version.
 
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 871d4ace12..20c344faea 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SCHEMA_VERSION = 69  # remember to update the list below when updating
+SCHEMA_VERSION = 70  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
@@ -62,6 +62,9 @@ Changes in SCHEMA_VERSION = 68:
 Changes in SCHEMA_VERSION = 69:
     - We now write to `device_lists_changes_in_room` table.
     - Use sequence to generate future `application_services_txns.txn_id`s
+
+Changes in SCHEMA_VERSION = 70:
+    - event_reference_hashes is no longer written to.
 """
 
 
diff --git a/synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql b/synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql
new file mode 100644
index 0000000000..22ae3b8c00
--- /dev/null
+++ b/synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql
@@ -0,0 +1,18 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Background update to clear the inboxes of hidden and deleted devices.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (6902, 'cache_invalidation_index_by_instance', '{}');
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index d1d5859214..d4a1bd4f9d 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -62,7 +62,7 @@ class StateFilter:
     types: "frozendict[str, Optional[FrozenSet[str]]]"
     include_others: bool = False
 
-    def __attrs_post_init__(self):
+    def __attrs_post_init__(self) -> None:
         # If `include_others` is set we canonicalise the filter by removing
         # wildcards from the types dictionary
         if self.include_others:
@@ -138,7 +138,9 @@ class StateFilter:
         )
 
     @staticmethod
-    def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
+    def freeze(
+        types: Mapping[str, Optional[Collection[str]]], include_others: bool
+    ) -> "StateFilter":
         """
         Returns a (frozen) StateFilter with the same contents as the parameters
         specified here, which can be made of mutable types.
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index d7d6f1d90e..0031df1e06 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
+from types import TracebackType
+from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
 from typing_extensions import Protocol
 
@@ -86,5 +87,80 @@ class Connection(Protocol):
     def __enter__(self) -> "Connection":
         ...
 
-    def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> Optional[bool]:
+        ...
+
+
+class DBAPI2Module(Protocol):
+    """The module-level attributes that we use from PEP 249.
+
+    This is NOT a comprehensive stub for the entire DBAPI2."""
+
+    __name__: str
+
+    # Exceptions. See https://peps.python.org/pep-0249/#exceptions
+
+    # For our specific drivers:
+    # - Python's sqlite3 module doesn't contains the same descriptions as the
+    #   DBAPI2 spec, see https://docs.python.org/3/library/sqlite3.html#exceptions
+    # - Psycopg2 maps every Postgres error code onto a unique exception class which
+    #   extends from this hierarchy. See
+    #     https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#exceptions
+    #     https://www.postgresql.org/docs/current/errcodes-appendix.html#ERRCODES-TABLE
+    Warning: Type[Exception]
+    Error: Type[Exception]
+
+    # Errors are divided into `InterfaceError`s (something went wrong in the database
+    # driver) and `DatabaseError`s (something went wrong in the database). These are
+    # both subclasses of `Error`, but we can't currently express this in type
+    # annotations due to https://github.com/python/mypy/issues/8397
+    InterfaceError: Type[Exception]
+    DatabaseError: Type[Exception]
+
+    # Everything below is a subclass of `DatabaseError`.
+
+    # Roughly: the database rejected a nonsensical value. Examples:
+    # - An integer was too big for its data type.
+    # - An invalid date time was provided.
+    # - A string contained a null code point.
+    DataError: Type[Exception]
+
+    # Roughly: something went wrong in the database, but it's not within the application
+    # programmer's control. Examples:
+    # - We failed to establish a connection to the database.
+    # - The connection to the database was lost.
+    # - A deadlock was detected.
+    # - A serialisation failure occurred.
+    # - The database ran out of resources, such as storage, memory, connections, etc.
+    # - The database encountered an error from the operating system.
+    OperationalError: Type[Exception]
+
+    # Roughly: we've given the database data which breaks a rule we asked it to enforce.
+    # Examples:
+    # - Stop, criminal scum! You violated the foreign key constraint
+    # - Also check constraints, non-null constraints, etc.
+    IntegrityError: Type[Exception]
+
+    # Roughly: something went wrong within the database server itself.
+    InternalError: Type[Exception]
+
+    # Roughly: the application did something silly that needs to be fixed. Examples:
+    # - We don't have permissions to do something.
+    # - We tried to create a table with duplicate column names.
+    # - We tried to use a reserved name.
+    # - We referred to a column that doesn't exist.
+    ProgrammingError: Type[Exception]
+
+    # Roughly: we've tried to do something that this database doesn't support.
+    NotSupportedError: Type[Exception]
+
+    def connect(self, **parameters: object) -> Connection:
         ...
+
+
+__all__ = ["Cursor", "Connection", "DBAPI2Module"]