summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/background_updates.py19
-rw-r--r--synapse/storage/database.py44
-rw-r--r--synapse/storage/databases/main/events.py236
-rw-r--r--synapse/storage/databases/main/metrics.py24
-rw-r--r--synapse/storage/databases/main/purge_events.py3
-rw-r--r--synapse/storage/databases/main/search.py33
-rw-r--r--synapse/storage/databases/main/stream.py34
-rw-r--r--synapse/storage/persist_events.py63
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/schema/__init__.py5
-rw-r--r--synapse/storage/state.py6
-rw-r--r--synapse/storage/types.py10
12 files changed, 252 insertions, 227 deletions
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 08c6eabc6d..c2bbbb574e 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,20 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from types import TracebackType
 from typing import (
     TYPE_CHECKING,
+    Any,
     AsyncContextManager,
     Awaitable,
     Callable,
     Dict,
     Iterable,
+    List,
     Optional,
+    Type,
 )
 
 import attr
 
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.types import Connection
+from synapse.storage.types import Connection, Cursor
 from synapse.types import JsonDict
 from synapse.util import Clock, json_encoder
 
@@ -74,7 +78,12 @@ class _BackgroundUpdateContextManager:
 
         return self._update_duration_ms
 
-    async def __aexit__(self, *exc) -> None:
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> None:
         pass
 
 
@@ -352,7 +361,7 @@ class BackgroundUpdater:
             True if we have finished running all the background updates, otherwise False
         """
 
-        def get_background_updates_txn(txn):
+        def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]:
             txn.execute(
                 """
                 SELECT update_name, depends_on FROM background_updates
@@ -469,7 +478,7 @@ class BackgroundUpdater:
         self,
         update_name: str,
         update_handler: Callable[[JsonDict, int], Awaitable[int]],
-    ):
+    ) -> None:
         """Register a handler for doing a background update.
 
         The handler should take two arguments:
@@ -603,7 +612,7 @@ class BackgroundUpdater:
         else:
             runner = create_index_sqlite
 
-        async def updater(progress, batch_size):
+        async def updater(progress: JsonDict, batch_size: int) -> int:
             if runner is not None:
                 logger.info("Adding index %s to %s", index_name, table)
                 await self.db_pool.runWithConnection(runner)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 41f566b648..5ddb58a8a2 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -31,6 +31,7 @@ from typing import (
     List,
     Optional,
     Tuple,
+    Type,
     TypeVar,
     cast,
     overload,
@@ -41,6 +42,7 @@ from prometheus_client import Histogram
 from typing_extensions import Concatenate, Literal, ParamSpec
 
 from twisted.enterprise import adbapi
+from twisted.internet.interfaces import IReactorCore
 
 from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
@@ -92,7 +94,9 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
 
 
 def make_pool(
-    reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+    reactor: IReactorCore,
+    db_config: DatabaseConnectionConfig,
+    engine: BaseDatabaseEngine,
 ) -> adbapi.ConnectionPool:
     """Get the connection pool for the database."""
 
@@ -101,7 +105,7 @@ def make_pool(
     db_args = dict(db_config.config.get("args", {}))
     db_args.setdefault("cp_reconnect", True)
 
-    def _on_new_connection(conn):
+    def _on_new_connection(conn: Connection) -> None:
         # Ensure we have a logging context so we can correctly track queries,
         # etc.
         with LoggingContext("db.on_new_connection"):
@@ -157,7 +161,11 @@ class LoggingDatabaseConnection:
     default_txn_name: str
 
     def cursor(
-        self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
+        self,
+        *,
+        txn_name: Optional[str] = None,
+        after_callbacks: Optional[List["_CallbackListEntry"]] = None,
+        exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
     ) -> "LoggingTransaction":
         if not txn_name:
             txn_name = self.default_txn_name
@@ -183,11 +191,16 @@ class LoggingDatabaseConnection:
         self.conn.__enter__()
         return self
 
-    def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> Optional[bool]:
         return self.conn.__exit__(exc_type, exc_value, traceback)
 
     # Proxy through any unknown lookups to the DB conn class.
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         return getattr(self.conn, name)
 
 
@@ -391,17 +404,22 @@ class LoggingTransaction:
     def __enter__(self) -> "LoggingTransaction":
         return self
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> None:
         self.close()
 
 
 class PerformanceCounters:
-    def __init__(self):
-        self.current_counters = {}
-        self.previous_counters = {}
+    def __init__(self) -> None:
+        self.current_counters: Dict[str, Tuple[int, float]] = {}
+        self.previous_counters: Dict[str, Tuple[int, float]] = {}
 
     def update(self, key: str, duration_secs: float) -> None:
-        count, cum_time = self.current_counters.get(key, (0, 0))
+        count, cum_time = self.current_counters.get(key, (0, 0.0))
         count += 1
         cum_time += duration_secs
         self.current_counters[key] = (count, cum_time)
@@ -527,7 +545,7 @@ class DatabasePool:
     def start_profiling(self) -> None:
         self._previous_loop_ts = monotonic_time()
 
-        def loop():
+        def loop() -> None:
             curr = self._current_txn_total_time
             prev = self._previous_txn_total_time
             self._previous_txn_total_time = curr
@@ -1186,7 +1204,7 @@ class DatabasePool:
         if lock:
             self.engine.lock_table(txn, table)
 
-        def _getwhere(key):
+        def _getwhere(key: str) -> str:
             # If the value we're passing in is None (aka NULL), we need to use
             # IS, not =, as NULL = NULL equals NULL (False).
             if keyvalues[key] is None:
@@ -2258,7 +2276,7 @@ class DatabasePool:
         term: Optional[str],
         col: str,
         retcols: Collection[str],
-        desc="simple_search_list",
+        desc: str = "simple_search_list",
     ) -> Optional[List[Dict[str, Any]]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ed29a0a5e2..42d484dc98 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -36,9 +36,8 @@ from prometheus_client import Counter
 import synapse.metrics
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.room_versions import RoomVersions
-from synapse.crypto.event_signing import compute_event_reference_hash
-from synapse.events import EventBase  # noqa: F401
-from synapse.events.snapshot import EventContext  # noqa: F401
+from synapse.events import EventBase, relation_from_event
+from synapse.events.snapshot import EventContext
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -50,7 +49,7 @@ from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.engines.postgres import PostgresEngine
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
 
@@ -129,7 +128,6 @@ class PersistEventsStore:
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         *,
-        current_state_for_room: Dict[str, StateMap[str]],
         state_delta_for_room: Dict[str, DeltaState],
         new_forward_extremities: Dict[str, Set[str]],
         use_negative_stream_ordering: bool = False,
@@ -140,8 +138,6 @@ class PersistEventsStore:
 
         Args:
             events_and_contexts:
-            current_state_for_room: Map from room_id to the current state of
-                the room based on forward extremities
             state_delta_for_room: Map from room_id to the delta to apply to
                 room state
             new_forward_extremities: Map from room_id to set of event IDs
@@ -216,9 +212,6 @@ class PersistEventsStore:
 
                 event_counter.labels(event.type, origin_type, origin_entity).inc()
 
-            for room_id, new_state in current_state_for_room.items():
-                self.store.get_current_state_ids.prefill((room_id,), new_state)
-
             for room_id, latest_event_ids in new_forward_extremities.items():
                 self.store.get_latest_event_ids_in_room.prefill(
                     (room_id,), list(latest_event_ids)
@@ -236,7 +229,9 @@ class PersistEventsStore:
         """
         results: List[str] = []
 
-        def _get_events_which_are_prevs_txn(txn, batch):
+        def _get_events_which_are_prevs_txn(
+            txn: LoggingTransaction, batch: Collection[str]
+        ) -> None:
             sql = """
             SELECT prev_event_id, internal_metadata
             FROM event_edges
@@ -286,7 +281,9 @@ class PersistEventsStore:
         # and their prev events.
         existing_prevs = set()
 
-        def _get_prevs_before_rejected_txn(txn, batch):
+        def _get_prevs_before_rejected_txn(
+            txn: LoggingTransaction, batch: Collection[str]
+        ) -> None:
             to_recursively_check = batch
 
             while to_recursively_check:
@@ -516,7 +513,7 @@ class PersistEventsStore:
     @classmethod
     def _add_chain_cover_index(
         cls,
-        txn,
+        txn: LoggingTransaction,
         db_pool: DatabasePool,
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
@@ -810,7 +807,7 @@ class PersistEventsStore:
 
     @staticmethod
     def _allocate_chain_ids(
-        txn,
+        txn: LoggingTransaction,
         db_pool: DatabasePool,
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
@@ -944,7 +941,7 @@ class PersistEventsStore:
         self,
         txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-    ):
+    ) -> None:
         """Persist the mapping from transaction IDs to event IDs (if defined)."""
 
         to_insert = []
@@ -998,7 +995,7 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         state_delta_by_room: Dict[str, DeltaState],
         stream_id: int,
-    ):
+    ) -> None:
         for room_id, delta_state in state_delta_by_room.items():
             to_delete = delta_state.to_delete
             to_insert = delta_state.to_insert
@@ -1156,7 +1153,7 @@ class PersistEventsStore:
                 txn, room_id, members_changed
             )
 
-    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
+    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         """Update the room version in the database based off current state
         events.
 
@@ -1190,7 +1187,7 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         new_forward_extremities: Dict[str, Set[str]],
         max_stream_order: int,
-    ):
+    ) -> None:
         for room_id in new_forward_extremities.keys():
             self.db_pool.simple_delete_txn(
                 txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
@@ -1255,9 +1252,9 @@ class PersistEventsStore:
 
     def _update_room_depths_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-    ):
+    ) -> None:
         """Update min_depth for each room
 
         Args:
@@ -1386,7 +1383,7 @@ class PersistEventsStore:
             # nothing to do here
             return
 
-        def event_dict(event):
+        def event_dict(event: EventBase) -> JsonDict:
             d = event.get_dict()
             d.pop("redacted", None)
             d.pop("redacted_because", None)
@@ -1477,18 +1474,20 @@ class PersistEventsStore:
             ),
         )
 
-    def _store_rejected_events_txn(self, txn, events_and_contexts):
+    def _store_rejected_events_txn(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> List[Tuple[EventBase, EventContext]]:
         """Add rows to the 'rejections' table for received events which were
         rejected
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
+            txn: db connection
+            events_and_contexts: events we are persisting
 
         Returns:
-            list[(EventBase, EventContext)] new list, without the rejected
-                events.
+            new list, without the rejected events.
         """
         # Remove the rejected events from the list now that we've added them
         # to the events table and the events_json table.
@@ -1509,7 +1508,7 @@ class PersistEventsStore:
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         all_events_and_contexts: List[Tuple[EventBase, EventContext]],
         inhibit_local_membership_updates: bool = False,
-    ):
+    ) -> None:
         """Update all the miscellaneous tables for new events
 
         Args:
@@ -1600,15 +1599,14 @@ class PersistEventsStore:
             inhibit_local_membership_updates=inhibit_local_membership_updates,
         )
 
-        # Insert event_reference_hashes table.
-        self._store_event_reference_hashes_txn(
-            txn, [event for event, _ in events_and_contexts]
-        )
-
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
 
-    def _add_to_cache(self, txn, events_and_contexts):
+    def _add_to_cache(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> None:
         to_prefill = []
 
         rows = []
@@ -1639,7 +1637,7 @@ class PersistEventsStore:
             if not row["rejects"] and not row["redacts"]:
                 to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
-        def prefill():
+        def prefill() -> None:
             for cache_entry in to_prefill:
                 self.store._get_event_cache.set(
                     (cache_entry.event.event_id,), cache_entry
@@ -1669,19 +1667,24 @@ class PersistEventsStore:
         )
 
     def insert_labels_for_event_txn(
-        self, txn, event_id, labels, room_id, topological_ordering
-    ):
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        labels: List[str],
+        room_id: str,
+        topological_ordering: int,
+    ) -> None:
         """Store the mapping between an event's ID and its labels, with one row per
         (event_id, label) tuple.
 
         Args:
-            txn (LoggingTransaction): The transaction to execute.
-            event_id (str): The event's ID.
-            labels (list[str]): A list of text labels.
-            room_id (str): The ID of the room the event was sent to.
-            topological_ordering (int): The position of the event in the room's topology.
+            txn: The transaction to execute.
+            event_id: The event's ID.
+            labels: A list of text labels.
+            room_id: The ID of the room the event was sent to.
+            topological_ordering: The position of the event in the room's topology.
         """
-        return self.db_pool.simple_insert_many_txn(
+        self.db_pool.simple_insert_many_txn(
             txn=txn,
             table="event_labels",
             keys=("event_id", "label", "room_id", "topological_ordering"),
@@ -1690,44 +1693,32 @@ class PersistEventsStore:
             ],
         )
 
-    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+    def _insert_event_expiry_txn(
+        self, txn: LoggingTransaction, event_id: str, expiry_ts: int
+    ) -> None:
         """Save the expiry timestamp associated with a given event ID.
 
         Args:
-            txn (LoggingTransaction): The database transaction to use.
-            event_id (str): The event ID the expiry timestamp is associated with.
-            expiry_ts (int): The timestamp at which to expire (delete) the event.
+            txn: The database transaction to use.
+            event_id: The event ID the expiry timestamp is associated with.
+            expiry_ts: The timestamp at which to expire (delete) the event.
         """
-        return self.db_pool.simple_insert_txn(
+        self.db_pool.simple_insert_txn(
             txn=txn,
             table="event_expiry",
             values={"event_id": event_id, "expiry_ts": expiry_ts},
         )
 
-    def _store_event_reference_hashes_txn(self, txn, events):
-        """Store a hash for a PDU
-        Args:
-            txn (cursor):
-            events (list): list of Events.
-        """
-
-        vals = []
-        for event in events:
-            ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
-            vals.append((event.event_id, ref_alg, memoryview(ref_hash_bytes)))
-
-        self.db_pool.simple_insert_many_txn(
-            txn,
-            table="event_reference_hashes",
-            keys=("event_id", "algorithm", "hash"),
-            values=vals,
-        )
-
     def _store_room_members_txn(
-        self, txn, events, *, inhibit_local_membership_updates: bool = False
-    ):
+        self,
+        txn: LoggingTransaction,
+        events: List[EventBase],
+        *,
+        inhibit_local_membership_updates: bool = False,
+    ) -> None:
         """
         Store a room member in the database.
+
         Args:
             txn: The transaction to use.
             events: List of events to store.
@@ -1767,6 +1758,7 @@ class PersistEventsStore:
         )
 
         for event in events:
+            assert event.internal_metadata.stream_ordering is not None
             txn.call_after(
                 self.store._membership_stream_cache.entity_has_changed,
                 event.state_key,
@@ -1815,55 +1807,50 @@ class PersistEventsStore:
             txn: The current database transaction.
             event: The event which might have relations.
         """
-        relation = event.content.get("m.relates_to")
+        relation = relation_from_event(event)
         if not relation:
-            # No relations
-            return
-
-        # Relations must have a type and parent event ID.
-        rel_type = relation.get("rel_type")
-        if not isinstance(rel_type, str):
+            # No relation, nothing to do.
             return
 
-        parent_id = relation.get("event_id")
-        if not isinstance(parent_id, str):
-            return
-
-        # Annotations have a key field.
-        aggregation_key = None
-        if rel_type == RelationTypes.ANNOTATION:
-            aggregation_key = relation.get("key")
-
         self.db_pool.simple_insert_txn(
             txn,
             table="event_relations",
             values={
                 "event_id": event.event_id,
-                "relates_to_id": parent_id,
-                "relation_type": rel_type,
-                "aggregation_key": aggregation_key,
+                "relates_to_id": relation.parent_id,
+                "relation_type": relation.rel_type,
+                "aggregation_key": relation.aggregation_key,
             },
         )
 
-        txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,))
         txn.call_after(
-            self.store.get_aggregation_groups_for_event.invalidate, (parent_id,)
+            self.store.get_relations_for_event.invalidate, (relation.parent_id,)
+        )
+        txn.call_after(
+            self.store.get_aggregation_groups_for_event.invalidate,
+            (relation.parent_id,),
         )
 
-        if rel_type == RelationTypes.REPLACE:
-            txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+        if relation.rel_type == RelationTypes.REPLACE:
+            txn.call_after(
+                self.store.get_applicable_edit.invalidate, (relation.parent_id,)
+            )
 
-        if rel_type == RelationTypes.THREAD:
-            txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+        if relation.rel_type == RelationTypes.THREAD:
+            txn.call_after(
+                self.store.get_thread_summary.invalidate, (relation.parent_id,)
+            )
             # It should be safe to only invalidate the cache if the user has not
             # previously participated in the thread, but that's difficult (and
             # potentially error-prone) so it is always invalidated.
             txn.call_after(
                 self.store.get_thread_participated.invalidate,
-                (parent_id, event.sender),
+                (relation.parent_id, event.sender),
             )
 
-    def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
+    def _handle_insertion_event(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         """Handles keeping track of insertion events and edges/connections.
         Part of MSC2716.
 
@@ -1924,7 +1911,7 @@ class PersistEventsStore:
                 },
             )
 
-    def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):
+    def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None:
         """Handles inserting the batch edges/connections between the batch event
         and an insertion event. Part of MSC2716.
 
@@ -2024,25 +2011,29 @@ class PersistEventsStore:
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )
 
-    def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
         if isinstance(event.content.get("topic"), str):
             self.store_event_search_txn(
                 txn, event, "content.topic", event.content["topic"]
             )
 
-    def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
         if isinstance(event.content.get("name"), str):
             self.store_event_search_txn(
                 txn, event, "content.name", event.content["name"]
             )
 
-    def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_message_txn(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         if isinstance(event.content.get("body"), str):
             self.store_event_search_txn(
                 txn, event, "content.body", event.content["body"]
             )
 
-    def _store_retention_policy_for_room_txn(self, txn, event):
+    def _store_retention_policy_for_room_txn(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         if not event.is_state():
             logger.debug("Ignoring non-state m.room.retention event")
             return
@@ -2102,8 +2093,11 @@ class PersistEventsStore:
         )
 
     def _set_push_actions_for_event_and_users_txn(
-        self, txn, events_and_contexts, all_events_and_contexts
-    ):
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        all_events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> None:
         """Handles moving push actions from staging table to main
         event_push_actions table for all events in `events_and_contexts`.
 
@@ -2111,12 +2105,10 @@ class PersistEventsStore:
         from the push action staging area.
 
         Args:
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
-            all_events_and_contexts (list[(EventBase, EventContext)]): all
-                events that we were going to persist. This includes events
-                we've already persisted, etc, that wouldn't appear in
-                events_and_context.
+            events_and_contexts: events we are persisting
+            all_events_and_contexts: all events that we were going to persist.
+                This includes events we've already persisted, etc, that wouldn't
+                appear in events_and_context.
         """
 
         # Only non outlier events will have push actions associated with them,
@@ -2185,7 +2177,9 @@ class PersistEventsStore:
             ),
         )
 
-    def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+    def _remove_push_actions_for_event_id_txn(
+        self, txn: LoggingTransaction, room_id: str, event_id: str
+    ) -> None:
         # Sad that we have to blow away the cache for the whole room here
         txn.call_after(
             self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
@@ -2196,7 +2190,9 @@ class PersistEventsStore:
             (room_id, event_id),
         )
 
-    def _store_rejections_txn(self, txn, event_id, reason):
+    def _store_rejections_txn(
+        self, txn: LoggingTransaction, event_id: str, reason: str
+    ) -> None:
         self.db_pool.simple_insert_txn(
             txn,
             table="rejections",
@@ -2208,8 +2204,10 @@ class PersistEventsStore:
         )
 
     def _store_event_state_mappings_txn(
-        self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
-    ):
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: Collection[Tuple[EventBase, EventContext]],
+    ) -> None:
         state_groups = {}
         for event, context in events_and_contexts:
             if event.internal_metadata.is_outlier():
@@ -2266,7 +2264,9 @@ class PersistEventsStore:
                 state_group_id,
             )
 
-    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+    def _update_min_depth_for_room_txn(
+        self, txn: LoggingTransaction, room_id: str, depth: int
+    ) -> None:
         min_depth = self.store._get_min_depth_interaction(txn, room_id)
 
         if min_depth is not None and depth >= min_depth:
@@ -2279,7 +2279,9 @@ class PersistEventsStore:
             values={"min_depth": depth},
         )
 
-    def _handle_mult_prev_events(self, txn, events):
+    def _handle_mult_prev_events(
+        self, txn: LoggingTransaction, events: List[EventBase]
+    ) -> None:
         """
         For the given event, update the event edges table and forward and
         backward extremities tables.
@@ -2297,7 +2299,9 @@ class PersistEventsStore:
 
         self._update_backward_extremeties(txn, events)
 
-    def _update_backward_extremeties(self, txn, events):
+    def _update_backward_extremeties(
+        self, txn: LoggingTransaction, events: List[EventBase]
+    ) -> None:
         """Updates the event_backward_extremities tables based on the new/updated
         events being persisted.
 
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 1480a0f048..d03555a585 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -23,6 +23,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
+from synapse.storage.types import Cursor
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -71,7 +72,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         self._last_user_visit_update = self._get_start_of_day()
 
     @wrap_as_background_process("read_forward_extremities")
-    async def _read_forward_extremities(self):
+    async def _read_forward_extremities(self) -> None:
         def fetch(txn):
             txn.execute(
                 """
@@ -95,7 +96,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (x[0] - 1) * x[1] for x in res if x[1]
         )
 
-    async def count_daily_e2ee_messages(self):
+    async def count_daily_e2ee_messages(self) -> int:
         """
         Returns an estimate of the number of messages sent in the last day.
 
@@ -115,7 +116,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
         return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
 
-    async def count_daily_sent_e2ee_messages(self):
+    async def count_daily_sent_e2ee_messages(self) -> int:
         def _count_messages(txn):
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.
@@ -136,7 +137,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_daily_sent_e2ee_messages", _count_messages
         )
 
-    async def count_daily_active_e2ee_rooms(self):
+    async def count_daily_active_e2ee_rooms(self) -> int:
         def _count(txn):
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events
@@ -151,7 +152,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_daily_active_e2ee_rooms", _count
         )
 
-    async def count_daily_messages(self):
+    async def count_daily_messages(self) -> int:
         """
         Returns an estimate of the number of messages sent in the last day.
 
@@ -171,7 +172,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
         return await self.db_pool.runInteraction("count_messages", _count_messages)
 
-    async def count_daily_sent_messages(self):
+    async def count_daily_sent_messages(self) -> int:
         def _count_messages(txn):
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.
@@ -192,7 +193,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_daily_sent_messages", _count_messages
         )
 
-    async def count_daily_active_rooms(self):
+    async def count_daily_active_rooms(self) -> int:
         def _count(txn):
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events
@@ -226,7 +227,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_monthly_users", self._count_users, thirty_days_ago
         )
 
-    def _count_users(self, txn, time_from):
+    def _count_users(self, txn: Cursor, time_from: int) -> int:
         """
         Returns number of users seen in the past time_from period
         """
@@ -238,7 +239,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             ) u
         """
         txn.execute(sql, (time_from,))
-        (count,) = txn.fetchone()
+        # Mypy knows that fetchone() might return None if there are no rows.
+        # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
+        # returns exactly one row.
+        (count,) = txn.fetchone()  # type: ignore[misc]
         return count
 
     async def count_r30_users(self) -> Dict[str, int]:
@@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_r30v2_users", _count_r30v2_users
         )
 
-    def _get_start_of_day(self):
+    def _get_start_of_day(self) -> int:
         """
         Returns millisecond unixtime for start of UTC day.
         """
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index bfc85b3add..38ba91af4c 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -69,7 +69,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         #     event_forward_extremities
         #     event_json
         #     event_push_actions
-        #     event_reference_hashes
         #     event_relations
         #     event_search
         #     event_to_state_groups
@@ -220,7 +219,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "event_auth",
             "event_edges",
             "event_forward_extremities",
-            "event_reference_hashes",
             "event_relations",
             "event_search",
             "rejections",
@@ -369,7 +367,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "event_edges",
             "event_json",
             "event_push_actions_staging",
-            "event_reference_hashes",
             "event_relations",
             "event_to_state_groups",
             "event_auth_chains",
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 3c49e7ec98..78e0773b2a 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -14,7 +14,7 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
 
 import attr
 
@@ -27,7 +27,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -149,7 +149,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
             self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
         )
 
-    async def _background_reindex_search(self, progress, batch_size):
+    async def _background_reindex_search(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         # we work through the events table from highest stream id to lowest
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
@@ -157,7 +159,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
 
-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT stream_ordering, event_id, room_id, type, json, "
                 " origin_server_ts FROM events"
@@ -255,12 +257,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         return result
 
-    async def _background_reindex_gin_search(self, progress, batch_size):
+    async def _background_reindex_gin_search(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """This handles old synapses which used GIST indexes, if any;
         converting them back to be GIN as per the actual schema.
         """
 
-        def create_index(conn):
+        def create_index(conn: LoggingDatabaseConnection) -> None:
             conn.rollback()
 
             # we have to set autocommit, because postgres refuses to
@@ -299,7 +303,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
         )
         return 1
 
-    async def _background_reindex_search_order(self, progress, batch_size):
+    async def _background_reindex_search_order(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
@@ -307,7 +313,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         if not have_added_index:
 
-            def create_index(conn):
+            def create_index(conn: LoggingDatabaseConnection) -> None:
                 conn.rollback()
                 conn.set_session(autocommit=True)
                 c = conn.cursor()
@@ -336,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
                 pg,
             )
 
-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
             sql = (
                 "UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
                 " origin_server_ts = e.origin_server_ts"
@@ -644,7 +650,8 @@ class SearchStore(SearchBackgroundUpdateStore):
         else:
             raise Exception("Unrecognized database engine")
 
-        args.append(limit)
+        # mypy expects to append only a `str`, not an `int`
+        args.append(limit)  # type: ignore[arg-type]
 
         results = await self.db_pool.execute(
             "search_rooms", self.db_pool.cursor_to_dict, sql, *args
@@ -705,7 +712,7 @@ class SearchStore(SearchBackgroundUpdateStore):
             A set of strings.
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> Set[str]:
             highlight_words = set()
             for event in events:
                 # As a hack we simply join values of all possible keys. This is
@@ -759,11 +766,11 @@ class SearchStore(SearchBackgroundUpdateStore):
         return await self.db_pool.runInteraction("_find_highlights", f)
 
 
-def _to_postgres_options(options_dict):
+def _to_postgres_options(options_dict: JsonDict) -> str:
     return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
 
 
-def _parse_query(database_engine, search_term):
+def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
     """Takes a plain unicode string from the user and converts it into a form
     that can be passed to database.
     We use this so that we can add prefix matching, which isn't something
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 3c3137fe64..944e2a5b0b 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -788,30 +788,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return None
 
     async def get_current_room_stream_token_for_room_id(
-        self, room_id: Optional[str] = None
+        self, room_id: str
     ) -> RoomStreamToken:
-        """Returns the current position of the rooms stream.
-
-        By default, it returns a live token with the current global stream
-        token. Specifying a `room_id` causes it to return a historic token with
-        the room specific topological token.
-        """
+        """Returns the current position of the rooms stream (historic token)."""
         stream_ordering = self.get_room_max_stream_ordering()
-        if room_id is None:
-            return RoomStreamToken(None, stream_ordering)
-        else:
-            topo = await self.db_pool.runInteraction(
-                "_get_max_topological_txn", self._get_max_topological_txn, room_id
-            )
-            return RoomStreamToken(topo, stream_ordering)
+        topo = await self.db_pool.runInteraction(
+            "_get_max_topological_txn", self._get_max_topological_txn, room_id
+        )
+        return RoomStreamToken(topo, stream_ordering)
 
     def get_stream_id_for_event_txn(
         self,
         txn: LoggingTransaction,
         event_id: str,
-        allow_none=False,
-    ) -> int:
-        return self.db_pool.simple_select_one_onecol_txn(
+        allow_none: bool = False,
+    ) -> Optional[int]:
+        # Type ignore: we pass keyvalues a Dict[str, str]; the function wants
+        # Dict[str, Any]. I think mypy is unhappy because Dict is invariant?
+        return self.db_pool.simple_select_one_onecol_txn(  # type: ignore[call-overload]
             txn=txn,
             table="events",
             keyvalues={"event_id": event_id},
@@ -873,7 +867,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
         rows = txn.fetchall()
-        return rows[0][0] if rows else 0
+        # An aggregate function like MAX() will always return one row per group
+        # so we can safely rely on the lookup here. For example, when a we
+        # lookup a `room_id` which does not exist, `rows` will look like
+        # `[(None,)]`
+        return rows[0][0] if rows[0][0] is not None else 0
 
     @staticmethod
     def _set_before_and_after(
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 97118045a1..0fc282866b 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -25,6 +25,7 @@ from typing import (
     Collection,
     Deque,
     Dict,
+    Generator,
     Generic,
     Iterable,
     List,
@@ -207,7 +208,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
 
         return res
 
-    def _handle_queue(self, room_id):
+    def _handle_queue(self, room_id: str) -> None:
         """Attempts to handle the queue for a room if not already being handled.
 
         The queue's callback will be invoked with for each item in the queue,
@@ -227,7 +228,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
 
         self._currently_persisting_rooms.add(room_id)
 
-        async def handle_queue_loop():
+        async def handle_queue_loop() -> None:
             try:
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
@@ -250,15 +251,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
                         with PreserveLoggingContext():
                             item.deferred.callback(ret)
             finally:
-                queue = self._event_persist_queues.pop(room_id, None)
-                if queue:
-                    self._event_persist_queues[room_id] = queue
+                remaining_queue = self._event_persist_queues.pop(room_id, None)
+                if remaining_queue:
+                    self._event_persist_queues[room_id] = remaining_queue
                 self._currently_persisting_rooms.discard(room_id)
 
         # set handle_queue_loop off in the background
         run_as_background_process("persist_events", handle_queue_loop)
 
-    def _get_drainining_queue(self, room_id):
+    def _get_drainining_queue(
+        self, room_id: str
+    ) -> Generator[_EventPersistQueueItem, None, None]:
         queue = self._event_persist_queues.setdefault(room_id, deque())
 
         try:
@@ -317,7 +320,9 @@ class EventsPersistenceStorage:
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
 
-        async def enqueue(item):
+        async def enqueue(
+            item: Tuple[str, List[Tuple[EventBase, EventContext]]]
+        ) -> Dict[str, str]:
             room_id, evs_ctxs = item
             return await self._event_persist_queue.add_to_queue(
                 room_id, evs_ctxs, backfilled=backfilled
@@ -487,12 +492,6 @@ class EventsPersistenceStorage:
             # extremities in each room
             new_forward_extremities: Dict[str, Set[str]] = {}
 
-            # map room_id->(type,state_key)->event_id tracking the full
-            # state in each room after adding these events.
-            # This is simply used to prefill the get_current_state_ids
-            # cache
-            current_state_for_room: Dict[str, StateMap[str]] = {}
-
             # map room_id->(to_delete, to_insert) where to_delete is a list
             # of type/state keys to remove from current state, and to_insert
             # is a map (type,key)->event_id giving the state delta in each
@@ -628,14 +627,8 @@ class EventsPersistenceStorage:
 
                             state_delta_for_room[room_id] = delta
 
-                        # If we have the current_state then lets prefill
-                        # the cache with it.
-                        if current_state is not None:
-                            current_state_for_room[room_id] = current_state
-
             await self.persist_events_store._persist_events_and_state_updates(
                 chunk,
-                current_state_for_room=current_state_for_room,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremities=new_forward_extremities,
                 use_negative_stream_ordering=backfilled,
@@ -733,7 +726,8 @@ class EventsPersistenceStorage:
 
             The first state map is the full new current state and the second
             is the delta to the existing current state. If both are None then
-            there has been no change.
+            there has been no change. Either or neither can be None if there
+            has been a change.
 
             The function may prune some old entries from the set of new
             forward extremities if it's safe to do so.
@@ -743,9 +737,6 @@ class EventsPersistenceStorage:
             the new current state is only returned if we've already calculated
             it.
         """
-        # map from state_group to ((type, key) -> event_id) state map
-        state_groups_map = {}
-
         # Map from (prev state group, new state group) -> delta state dict
         state_group_deltas = {}
 
@@ -759,16 +750,6 @@ class EventsPersistenceStorage:
                     )
                 continue
 
-            if ctx.state_group in state_groups_map:
-                continue
-
-            # We're only interested in pulling out state that has already
-            # been cached in the context. We'll pull stuff out of the DB later
-            # if necessary.
-            current_state_ids = ctx.get_cached_current_state_ids()
-            if current_state_ids is not None:
-                state_groups_map[ctx.state_group] = current_state_ids
-
             if ctx.prev_group:
                 state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
 
@@ -826,18 +807,14 @@ class EventsPersistenceStorage:
             delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
             if delta_ids is not None:
                 # We have a delta from the existing to new current state,
-                # so lets just return that. If we happen to already have
-                # the current state in memory then lets also return that,
-                # but it doesn't matter if we don't.
-                new_state = state_groups_map.get(new_state_group)
-                return new_state, delta_ids, new_latest_event_ids
+                # so lets just return that.
+                return None, delta_ids, new_latest_event_ids
 
         # Now that we have calculated new_state_groups we need to get
         # their state IDs so we can resolve to a single state set.
-        missing_state = new_state_groups - set(state_groups_map)
-        if missing_state:
-            group_to_state = await self.state_store._get_state_for_groups(missing_state)
-            state_groups_map.update(group_to_state)
+        state_groups_map = await self.state_store._get_state_for_groups(
+            new_state_groups
+        )
 
         if len(new_state_groups) == 1:
             # If there is only one state group, then we know what the current
@@ -1130,7 +1107,7 @@ class EventsPersistenceStorage:
 
         return False
 
-    async def _handle_potentially_left_users(self, user_ids: Set[str]):
+    async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None:
         """Given a set of remote users check if the server still shares a room with
         them. If not then mark those users' device cache as stale.
         """
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 546d6bae6e..c33df42084 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -85,7 +85,7 @@ def prepare_database(
     database_engine: BaseDatabaseEngine,
     config: Optional[HomeServerConfig],
     databases: Collection[str] = ("main", "state"),
-):
+) -> None:
     """Prepares a physical database for usage. Will either create all necessary tables
     or upgrade from an older schema version.
 
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 871d4ace12..20c344faea 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SCHEMA_VERSION = 69  # remember to update the list below when updating
+SCHEMA_VERSION = 70  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
@@ -62,6 +62,9 @@ Changes in SCHEMA_VERSION = 68:
 Changes in SCHEMA_VERSION = 69:
     - We now write to `device_lists_changes_in_room` table.
     - Use sequence to generate future `application_services_txns.txn_id`s
+
+Changes in SCHEMA_VERSION = 70:
+    - event_reference_hashes is no longer written to.
 """
 
 
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index d1d5859214..d4a1bd4f9d 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -62,7 +62,7 @@ class StateFilter:
     types: "frozendict[str, Optional[FrozenSet[str]]]"
     include_others: bool = False
 
-    def __attrs_post_init__(self):
+    def __attrs_post_init__(self) -> None:
         # If `include_others` is set we canonicalise the filter by removing
         # wildcards from the types dictionary
         if self.include_others:
@@ -138,7 +138,9 @@ class StateFilter:
         )
 
     @staticmethod
-    def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
+    def freeze(
+        types: Mapping[str, Optional[Collection[str]]], include_others: bool
+    ) -> "StateFilter":
         """
         Returns a (frozen) StateFilter with the same contents as the parameters
         specified here, which can be made of mutable types.
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index d7d6f1d90e..40536c1830 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
+from types import TracebackType
+from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
 from typing_extensions import Protocol
 
@@ -86,5 +87,10 @@ class Connection(Protocol):
     def __enter__(self) -> "Connection":
         ...
 
-    def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> Optional[bool]:
         ...