Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--  synapse/storage/databases/main/cache.py         |   8
-rw-r--r--  synapse/storage/databases/main/e2e_room_keys.py |   4
-rw-r--r--  synapse/storage/databases/main/events.py        | 236
-rw-r--r--  synapse/storage/databases/main/events_worker.py |  35
-rw-r--r--  synapse/storage/databases/main/metrics.py       |  24
-rw-r--r--  synapse/storage/databases/main/purge_events.py  |   3
-rw-r--r--  synapse/storage/databases/main/relations.py     |   6
-rw-r--r--  synapse/storage/databases/main/search.py        |  33
-rw-r--r--  synapse/storage/databases/main/stream.py        |  34
9 files changed, 217 insertions(+), 166 deletions(-)
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index dd4e83a2ad..1653a6a9b6 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -57,6 +57,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         self._instance_name = hs.get_instance_name()
 
+        self.db_pool.updates.register_background_index_update(
+            update_name="cache_invalidation_index_by_instance",
+            index_name="cache_invalidation_stream_by_instance_instance_index",
+            table="cache_invalidation_stream_by_instance",
+            columns=("instance_name", "stream_id"),
+            psql_only=True,  # The table is only on postgres DBs.
+        )
+
     async def get_all_updated_caches(
         self, instance_name: str, last_id: int, current_id: int, limit: int
     ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
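
Note on the cache.py hunk above: register_background_index_update schedules the index build to run lazily through the background-updates system, and psql_only=True skips it on SQLite, where the cache invalidation stream table does not exist. As a rough sketch of what the updater ultimately amounts to on PostgreSQL (the CONCURRENTLY / IF NOT EXISTS details below are assumptions about the helper's behaviour, not something this diff shows):

    import psycopg2  # illustration only; assumes a plain psycopg2 connection

    def create_cache_invalidation_index(conn: "psycopg2.extensions.connection") -> None:
        # CREATE INDEX CONCURRENTLY cannot run inside a transaction block,
        # so autocommit is required.
        conn.autocommit = True
        with conn.cursor() as cur:
            cur.execute(
                "CREATE INDEX CONCURRENTLY IF NOT EXISTS"
                " cache_invalidation_stream_by_instance_instance_index"
                " ON cache_invalidation_stream_by_instance (instance_name, stream_id)"
            )
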
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index b789a588a5..af59be6b48 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -21,7 +21,7 @@ from synapse.api.errors import StoreError
 from synapse.logging.opentracing import log_kv, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import LoggingTransaction
-from synapse.types import JsonDict, JsonSerializable
+from synapse.types import JsonDict, JsonSerializable, StreamKeyType
 from synapse.util import json_encoder
 
 
@@ -126,7 +126,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                     "message": "Set room key",
                     "room_id": room_id,
                     "session_id": session_id,
-                    "room_key": room_key,
+                    StreamKeyType.ROOM: room_key,
                 }
             )
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ed29a0a5e2..42d484dc98 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -36,9 +36,8 @@ from prometheus_client import Counter
 import synapse.metrics
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.room_versions import RoomVersions
-from synapse.crypto.event_signing import compute_event_reference_hash
-from synapse.events import EventBase  # noqa: F401
-from synapse.events.snapshot import EventContext  # noqa: F401
+from synapse.events import EventBase, relation_from_event
+from synapse.events.snapshot import EventContext
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -50,7 +49,7 @@ from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.engines.postgres import PostgresEngine
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
 
@@ -129,7 +128,6 @@ class PersistEventsStore:
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         *,
-        current_state_for_room: Dict[str, StateMap[str]],
         state_delta_for_room: Dict[str, DeltaState],
         new_forward_extremities: Dict[str, Set[str]],
         use_negative_stream_ordering: bool = False,
@@ -140,8 +138,6 @@ class PersistEventsStore:
 
         Args:
             events_and_contexts:
-            current_state_for_room: Map from room_id to the current state of
-                the room based on forward extremities
             state_delta_for_room: Map from room_id to the delta to apply to
                 room state
             new_forward_extremities: Map from room_id to set of event IDs
@@ -216,9 +212,6 @@ class PersistEventsStore:
 
                 event_counter.labels(event.type, origin_type, origin_entity).inc()
 
-            for room_id, new_state in current_state_for_room.items():
-                self.store.get_current_state_ids.prefill((room_id,), new_state)
-
             for room_id, latest_event_ids in new_forward_extremities.items():
                 self.store.get_latest_event_ids_in_room.prefill(
                     (room_id,), list(latest_event_ids)
@@ -236,7 +229,9 @@ class PersistEventsStore:
         """
         results: List[str] = []
 
-        def _get_events_which_are_prevs_txn(txn, batch):
+        def _get_events_which_are_prevs_txn(
+            txn: LoggingTransaction, batch: Collection[str]
+        ) -> None:
             sql = """
             SELECT prev_event_id, internal_metadata
             FROM event_edges
@@ -286,7 +281,9 @@ class PersistEventsStore:
         # and their prev events.
         existing_prevs = set()
 
-        def _get_prevs_before_rejected_txn(txn, batch):
+        def _get_prevs_before_rejected_txn(
+            txn: LoggingTransaction, batch: Collection[str]
+        ) -> None:
             to_recursively_check = batch
 
             while to_recursively_check:
@@ -516,7 +513,7 @@ class PersistEventsStore:
     @classmethod
     def _add_chain_cover_index(
         cls,
-        txn,
+        txn: LoggingTransaction,
         db_pool: DatabasePool,
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
@@ -810,7 +807,7 @@ class PersistEventsStore:
 
     @staticmethod
     def _allocate_chain_ids(
-        txn,
+        txn: LoggingTransaction,
         db_pool: DatabasePool,
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
@@ -944,7 +941,7 @@ class PersistEventsStore:
         self,
         txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-    ):
+    ) -> None:
         """Persist the mapping from transaction IDs to event IDs (if defined)."""
 
         to_insert = []
@@ -998,7 +995,7 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         state_delta_by_room: Dict[str, DeltaState],
         stream_id: int,
-    ):
+    ) -> None:
         for room_id, delta_state in state_delta_by_room.items():
             to_delete = delta_state.to_delete
             to_insert = delta_state.to_insert
@@ -1156,7 +1153,7 @@ class PersistEventsStore:
                 txn, room_id, members_changed
             )
 
-    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
+    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         """Update the room version in the database based off current state
         events.
 
@@ -1190,7 +1187,7 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         new_forward_extremities: Dict[str, Set[str]],
         max_stream_order: int,
-    ):
+    ) -> None:
         for room_id in new_forward_extremities.keys():
             self.db_pool.simple_delete_txn(
                 txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
@@ -1255,9 +1252,9 @@ class PersistEventsStore:
 
     def _update_room_depths_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-    ):
+    ) -> None:
         """Update min_depth for each room
 
         Args:
@@ -1386,7 +1383,7 @@ class PersistEventsStore:
             # nothing to do here
             return
 
-        def event_dict(event):
+        def event_dict(event: EventBase) -> JsonDict:
             d = event.get_dict()
             d.pop("redacted", None)
             d.pop("redacted_because", None)
@@ -1477,18 +1474,20 @@ class PersistEventsStore:
             ),
         )
 
-    def _store_rejected_events_txn(self, txn, events_and_contexts):
+    def _store_rejected_events_txn(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> List[Tuple[EventBase, EventContext]]:
         """Add rows to the 'rejections' table for received events which were
         rejected
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
+            txn: db connection
+            events_and_contexts: events we are persisting
 
         Returns:
-            list[(EventBase, EventContext)] new list, without the rejected
-                events.
+            new list, without the rejected events.
         """
         # Remove the rejected events from the list now that we've added them
         # to the events table and the events_json table.
@@ -1509,7 +1508,7 @@ class PersistEventsStore:
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         all_events_and_contexts: List[Tuple[EventBase, EventContext]],
         inhibit_local_membership_updates: bool = False,
-    ):
+    ) -> None:
         """Update all the miscellaneous tables for new events
 
         Args:
@@ -1600,15 +1599,14 @@ class PersistEventsStore:
             inhibit_local_membership_updates=inhibit_local_membership_updates,
         )
 
-        # Insert event_reference_hashes table.
-        self._store_event_reference_hashes_txn(
-            txn, [event for event, _ in events_and_contexts]
-        )
-
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
 
-    def _add_to_cache(self, txn, events_and_contexts):
+    def _add_to_cache(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> None:
         to_prefill = []
 
         rows = []
@@ -1639,7 +1637,7 @@ class PersistEventsStore:
             if not row["rejects"] and not row["redacts"]:
                 to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
-        def prefill():
+        def prefill() -> None:
             for cache_entry in to_prefill:
                 self.store._get_event_cache.set(
                     (cache_entry.event.event_id,), cache_entry
@@ -1669,19 +1667,24 @@ class PersistEventsStore:
         )
 
     def insert_labels_for_event_txn(
-        self, txn, event_id, labels, room_id, topological_ordering
-    ):
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        labels: List[str],
+        room_id: str,
+        topological_ordering: int,
+    ) -> None:
         """Store the mapping between an event's ID and its labels, with one row per
         (event_id, label) tuple.
 
         Args:
-            txn (LoggingTransaction): The transaction to execute.
-            event_id (str): The event's ID.
-            labels (list[str]): A list of text labels.
-            room_id (str): The ID of the room the event was sent to.
-            topological_ordering (int): The position of the event in the room's topology.
+            txn: The transaction to execute.
+            event_id: The event's ID.
+            labels: A list of text labels.
+            room_id: The ID of the room the event was sent to.
+            topological_ordering: The position of the event in the room's topology.
         """
-        return self.db_pool.simple_insert_many_txn(
+        self.db_pool.simple_insert_many_txn(
             txn=txn,
             table="event_labels",
             keys=("event_id", "label", "room_id", "topological_ordering"),
@@ -1690,44 +1693,32 @@ class PersistEventsStore:
             ],
         )
 
-    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+    def _insert_event_expiry_txn(
+        self, txn: LoggingTransaction, event_id: str, expiry_ts: int
+    ) -> None:
         """Save the expiry timestamp associated with a given event ID.
 
         Args:
-            txn (LoggingTransaction): The database transaction to use.
-            event_id (str): The event ID the expiry timestamp is associated with.
-            expiry_ts (int): The timestamp at which to expire (delete) the event.
+            txn: The database transaction to use.
+            event_id: The event ID the expiry timestamp is associated with.
+            expiry_ts: The timestamp at which to expire (delete) the event.
         """
-        return self.db_pool.simple_insert_txn(
+        self.db_pool.simple_insert_txn(
             txn=txn,
             table="event_expiry",
             values={"event_id": event_id, "expiry_ts": expiry_ts},
         )
 
-    def _store_event_reference_hashes_txn(self, txn, events):
-        """Store a hash for a PDU
-        Args:
-            txn (cursor):
-            events (list): list of Events.
-        """
-
-        vals = []
-        for event in events:
-            ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
-            vals.append((event.event_id, ref_alg, memoryview(ref_hash_bytes)))
-
-        self.db_pool.simple_insert_many_txn(
-            txn,
-            table="event_reference_hashes",
-            keys=("event_id", "algorithm", "hash"),
-            values=vals,
-        )
-
     def _store_room_members_txn(
-        self, txn, events, *, inhibit_local_membership_updates: bool = False
-    ):
+        self,
+        txn: LoggingTransaction,
+        events: List[EventBase],
+        *,
+        inhibit_local_membership_updates: bool = False,
+    ) -> None:
         """
         Store a room member in the database.
+
         Args:
             txn: The transaction to use.
             events: List of events to store.
@@ -1767,6 +1758,7 @@ class PersistEventsStore:
         )
 
         for event in events:
+            assert event.internal_metadata.stream_ordering is not None
             txn.call_after(
                 self.store._membership_stream_cache.entity_has_changed,
                 event.state_key,
@@ -1815,55 +1807,50 @@ class PersistEventsStore:
             txn: The current database transaction.
             event: The event which might have relations.
         """
-        relation = event.content.get("m.relates_to")
+        relation = relation_from_event(event)
         if not relation:
-            # No relations
-            return
-
-        # Relations must have a type and parent event ID.
-        rel_type = relation.get("rel_type")
-        if not isinstance(rel_type, str):
+            # No relation, nothing to do.
             return
 
-        parent_id = relation.get("event_id")
-        if not isinstance(parent_id, str):
-            return
-
-        # Annotations have a key field.
-        aggregation_key = None
-        if rel_type == RelationTypes.ANNOTATION:
-            aggregation_key = relation.get("key")
-
         self.db_pool.simple_insert_txn(
             txn,
             table="event_relations",
             values={
                 "event_id": event.event_id,
-                "relates_to_id": parent_id,
-                "relation_type": rel_type,
-                "aggregation_key": aggregation_key,
+                "relates_to_id": relation.parent_id,
+                "relation_type": relation.rel_type,
+                "aggregation_key": relation.aggregation_key,
             },
         )
 
-        txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,))
         txn.call_after(
-            self.store.get_aggregation_groups_for_event.invalidate, (parent_id,)
+            self.store.get_relations_for_event.invalidate, (relation.parent_id,)
+        )
+        txn.call_after(
+            self.store.get_aggregation_groups_for_event.invalidate,
+            (relation.parent_id,),
         )
 
-        if rel_type == RelationTypes.REPLACE:
-            txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+        if relation.rel_type == RelationTypes.REPLACE:
+            txn.call_after(
+                self.store.get_applicable_edit.invalidate, (relation.parent_id,)
+            )
 
-        if rel_type == RelationTypes.THREAD:
-            txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+        if relation.rel_type == RelationTypes.THREAD:
+            txn.call_after(
+                self.store.get_thread_summary.invalidate, (relation.parent_id,)
+            )
             # It should be safe to only invalidate the cache if the user has not
             # previously participated in the thread, but that's difficult (and
             # potentially error-prone) so it is always invalidated.
             txn.call_after(
                 self.store.get_thread_participated.invalidate,
-                (parent_id, event.sender),
+                (relation.parent_id, event.sender),
             )
 
-    def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
+    def _handle_insertion_event(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         """Handles keeping track of insertion events and edges/connections.
         Part of MSC2716.
 
@@ -1924,7 +1911,7 @@ class PersistEventsStore:
                 },
             )
 
-    def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):
+    def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None:
         """Handles inserting the batch edges/connections between the batch event
         and an insertion event. Part of MSC2716.
 
@@ -2024,25 +2011,29 @@ class PersistEventsStore:
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )
 
-    def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
         if isinstance(event.content.get("topic"), str):
             self.store_event_search_txn(
                 txn, event, "content.topic", event.content["topic"]
             )
 
-    def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
         if isinstance(event.content.get("name"), str):
             self.store_event_search_txn(
                 txn, event, "content.name", event.content["name"]
             )
 
-    def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_message_txn(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         if isinstance(event.content.get("body"), str):
             self.store_event_search_txn(
                 txn, event, "content.body", event.content["body"]
             )
 
-    def _store_retention_policy_for_room_txn(self, txn, event):
+    def _store_retention_policy_for_room_txn(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         if not event.is_state():
             logger.debug("Ignoring non-state m.room.retention event")
             return
@@ -2102,8 +2093,11 @@ class PersistEventsStore:
         )
 
     def _set_push_actions_for_event_and_users_txn(
-        self, txn, events_and_contexts, all_events_and_contexts
-    ):
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        all_events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> None:
         """Handles moving push actions from staging table to main
         event_push_actions table for all events in `events_and_contexts`.
 
@@ -2111,12 +2105,10 @@ class PersistEventsStore:
         from the push action staging area.
 
         Args:
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
-            all_events_and_contexts (list[(EventBase, EventContext)]): all
-                events that we were going to persist. This includes events
-                we've already persisted, etc, that wouldn't appear in
-                events_and_context.
+            events_and_contexts: events we are persisting
+            all_events_and_contexts: all events that we were going to persist.
+                This includes events we've already persisted, etc, that wouldn't
+                appear in events_and_context.
         """
 
         # Only non outlier events will have push actions associated with them,
@@ -2185,7 +2177,9 @@ class PersistEventsStore:
             ),
         )
 
-    def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+    def _remove_push_actions_for_event_id_txn(
+        self, txn: LoggingTransaction, room_id: str, event_id: str
+    ) -> None:
         # Sad that we have to blow away the cache for the whole room here
         txn.call_after(
             self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
@@ -2196,7 +2190,9 @@ class PersistEventsStore:
             (room_id, event_id),
         )
 
-    def _store_rejections_txn(self, txn, event_id, reason):
+    def _store_rejections_txn(
+        self, txn: LoggingTransaction, event_id: str, reason: str
+    ) -> None:
         self.db_pool.simple_insert_txn(
             txn,
             table="rejections",
@@ -2208,8 +2204,10 @@ class PersistEventsStore:
         )
 
     def _store_event_state_mappings_txn(
-        self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
-    ):
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: Collection[Tuple[EventBase, EventContext]],
+    ) -> None:
         state_groups = {}
         for event, context in events_and_contexts:
             if event.internal_metadata.is_outlier():
@@ -2266,7 +2264,9 @@ class PersistEventsStore:
                 state_group_id,
             )
 
-    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+    def _update_min_depth_for_room_txn(
+        self, txn: LoggingTransaction, room_id: str, depth: int
+    ) -> None:
         min_depth = self.store._get_min_depth_interaction(txn, room_id)
 
         if min_depth is not None and depth >= min_depth:
@@ -2279,7 +2279,9 @@ class PersistEventsStore:
             values={"min_depth": depth},
         )
 
-    def _handle_mult_prev_events(self, txn, events):
+    def _handle_mult_prev_events(
+        self, txn: LoggingTransaction, events: List[EventBase]
+    ) -> None:
         """
         For the given event, update the event edges table and forward and
         backward extremities tables.
@@ -2297,7 +2299,9 @@ class PersistEventsStore:
 
         self._update_backward_extremeties(txn, events)
 
-    def _update_backward_extremeties(self, txn, events):
+    def _update_backward_extremeties(
+        self, txn: LoggingTransaction, events: List[EventBase]
+    ) -> None:
         """Updates the event_backward_extremities tables based on the new/updated
         events being persisted.
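
Note on the events.py changes above: the inline parsing of m.relates_to (rel_type, parent event ID, and the annotation key) moves into relation_from_event, so this store only deals with an already-validated relation object; separately, _store_event_reference_hashes_txn is removed, so new events no longer populate event_reference_hashes. A hedged reconstruction of what the new helper plausibly does, pieced together from the deleted lines (the real implementation lives in synapse.events and may differ):

    from dataclasses import dataclass
    from typing import Any, Dict, Optional


    @dataclass(frozen=True)
    class RelationSketch:
        parent_id: str
        rel_type: str
        aggregation_key: Optional[str]


    def relation_from_event_sketch(content: Dict[str, Any]) -> Optional[RelationSketch]:
        """Mirrors the validation removed from _handle_event_relations above."""
        relation = content.get("m.relates_to")
        if not isinstance(relation, dict):
            # No (usable) relation, nothing to do.
            return None

        # Relations must have a string rel_type and parent event ID.
        rel_type = relation.get("rel_type")
        parent_id = relation.get("event_id")
        if not isinstance(rel_type, str) or not isinstance(parent_id, str):
            return None

        # Annotations additionally carry an aggregation key.
        aggregation_key = None
        if rel_type == "m.annotation":
            key = relation.get("key")
            if isinstance(key, str):
                aggregation_key = key

        return RelationSketch(parent_id, rel_type, aggregation_key)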
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index a4a604a499..5b22d6b452 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -14,6 +14,7 @@
 
 import logging
 import threading
+import weakref
 from enum import Enum, auto
 from typing import (
     TYPE_CHECKING,
@@ -23,6 +24,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    MutableMapping,
     Optional,
     Set,
     Tuple,
@@ -248,6 +250,12 @@ class EventsWorkerStore(SQLBaseStore):
             str, ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = {}
 
+        # We keep track of the events we have currently loaded in memory so that
+        # we can reuse them even if they've been evicted from the cache. We only
+        # track events that don't need redacting in here (as then we don't need
+        # to track redaction status).
+        self._event_ref: MutableMapping[str, EventBase] = weakref.WeakValueDictionary()
+
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list: List[
             Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
@@ -723,6 +731,8 @@ class EventsWorkerStore(SQLBaseStore):
 
     def _invalidate_get_event_cache(self, event_id: str) -> None:
         self._get_event_cache.invalidate((event_id,))
+        self._event_ref.pop(event_id, None)
+        self._current_event_fetches.pop(event_id, None)
 
     def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
@@ -738,13 +748,30 @@ class EventsWorkerStore(SQLBaseStore):
         event_map = {}
 
         for event_id in events:
+            # First check if it's in the event cache
             ret = self._get_event_cache.get(
                 (event_id,), None, update_metrics=update_metrics
             )
-            if not ret:
+            if ret:
+                event_map[event_id] = ret
                 continue
 
-            event_map[event_id] = ret
+            # Otherwise check if we still have the event in memory.
+            event = self._event_ref.get(event_id)
+            if event:
+                # Reconstruct an event cache entry
+
+                cache_entry = EventCacheEntry(
+                    event=event,
+                    # We don't cache weakrefs to redacted events, so we know
+                    # this is None.
+                    redacted_event=None,
+                )
+                event_map[event_id] = cache_entry
+
+                # We add the entry back into the cache as we want to keep
+                # recently queried events in the cache.
+                self._get_event_cache.set((event_id,), cache_entry)
 
         return event_map
 
@@ -1124,6 +1151,10 @@ class EventsWorkerStore(SQLBaseStore):
             self._get_event_cache.set((event_id,), cache_entry)
             result_map[event_id] = cache_entry
 
+            if not redacted_event:
+                # We only cache references to unredacted events.
+                self._event_ref[event_id] = original_ev
+
         return result_map
 
     async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
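
Note on the events_worker.py hunk: _event_ref is a weakref.WeakValueDictionary, so an event can be recovered from it after cache eviction only while something else in the process still holds a strong reference; once the last reference is dropped, the entry silently disappears and lookups fall back to the database. A minimal standalone sketch of that behaviour (toy Event class, not Synapse's EventBase):

    import gc
    import weakref


    class Event:
        """Toy stand-in for EventBase; just needs to support weak references."""

        def __init__(self, event_id: str) -> None:
            self.event_id = event_id


    event_ref: "weakref.WeakValueDictionary[str, Event]" = weakref.WeakValueDictionary()

    ev = Event("$abc")
    event_ref["$abc"] = ev
    assert event_ref.get("$abc") is ev    # reachable while a strong reference exists

    del ev
    gc.collect()  # immediate on CPython refcounting; collect() covers other implementations
    assert event_ref.get("$abc") is None  # the entry vanished with the last strong reference
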
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 1480a0f048..d03555a585 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -23,6 +23,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
+from synapse.storage.types import Cursor
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -71,7 +72,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         self._last_user_visit_update = self._get_start_of_day()
 
     @wrap_as_background_process("read_forward_extremities")
-    async def _read_forward_extremities(self):
+    async def _read_forward_extremities(self) -> None:
         def fetch(txn):
             txn.execute(
                 """
@@ -95,7 +96,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (x[0] - 1) * x[1] for x in res if x[1]
         )
 
-    async def count_daily_e2ee_messages(self):
+    async def count_daily_e2ee_messages(self) -> int:
         """
         Returns an estimate of the number of messages sent in the last day.
 
@@ -115,7 +116,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
         return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
 
-    async def count_daily_sent_e2ee_messages(self):
+    async def count_daily_sent_e2ee_messages(self) -> int:
         def _count_messages(txn):
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.
@@ -136,7 +137,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_daily_sent_e2ee_messages", _count_messages
         )
 
-    async def count_daily_active_e2ee_rooms(self):
+    async def count_daily_active_e2ee_rooms(self) -> int:
         def _count(txn):
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events
@@ -151,7 +152,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_daily_active_e2ee_rooms", _count
         )
 
-    async def count_daily_messages(self):
+    async def count_daily_messages(self) -> int:
         """
         Returns an estimate of the number of messages sent in the last day.
 
@@ -171,7 +172,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
         return await self.db_pool.runInteraction("count_messages", _count_messages)
 
-    async def count_daily_sent_messages(self):
+    async def count_daily_sent_messages(self) -> int:
         def _count_messages(txn):
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.
@@ -192,7 +193,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_daily_sent_messages", _count_messages
         )
 
-    async def count_daily_active_rooms(self):
+    async def count_daily_active_rooms(self) -> int:
         def _count(txn):
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events
@@ -226,7 +227,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_monthly_users", self._count_users, thirty_days_ago
         )
 
-    def _count_users(self, txn, time_from):
+    def _count_users(self, txn: Cursor, time_from: int) -> int:
         """
         Returns number of users seen in the past time_from period
         """
@@ -238,7 +239,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             ) u
         """
         txn.execute(sql, (time_from,))
-        (count,) = txn.fetchone()
+        # Mypy knows that fetchone() might return None if there are no rows.
+        # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
+        # returns exactly one row.
+        (count,) = txn.fetchone()  # type: ignore[misc]
         return count
 
     async def count_r30_users(self) -> Dict[str, int]:
@@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_r30v2_users", _count_r30v2_users
         )
 
-    def _get_start_of_day(self):
+    def _get_start_of_day(self) -> int:
         """
         Returns millisecond unixtime for start of UTC day.
         """
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index bfc85b3add..38ba91af4c 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -69,7 +69,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         #     event_forward_extremities
         #     event_json
         #     event_push_actions
-        #     event_reference_hashes
         #     event_relations
         #     event_search
         #     event_to_state_groups
@@ -220,7 +219,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "event_auth",
             "event_edges",
             "event_forward_extremities",
-            "event_reference_hashes",
             "event_relations",
             "event_search",
             "rejections",
@@ -369,7 +367,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "event_edges",
             "event_json",
             "event_push_actions_staging",
-            "event_reference_hashes",
             "event_relations",
             "event_to_state_groups",
             "event_auth_chains",
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 484976ca6b..fe8fded88b 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -34,7 +34,7 @@ from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 logger = logging.getLogger(__name__)
@@ -161,7 +161,9 @@ class RelationsWorkerStore(SQLBaseStore):
             if len(events) > limit and last_topo_id and last_stream_id:
                 next_key = RoomStreamToken(last_topo_id, last_stream_id)
                 if from_token:
-                    next_token = from_token.copy_and_replace("room_key", next_key)
+                    next_token = from_token.copy_and_replace(
+                        StreamKeyType.ROOM, next_key
+                    )
                 else:
                     next_token = StreamToken(
                         room_key=next_key,
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 3c49e7ec98..78e0773b2a 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -14,7 +14,7 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
 
 import attr
 
@@ -27,7 +27,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -149,7 +149,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
             self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
         )
 
-    async def _background_reindex_search(self, progress, batch_size):
+    async def _background_reindex_search(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         # we work through the events table from highest stream id to lowest
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
@@ -157,7 +159,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
 
-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT stream_ordering, event_id, room_id, type, json, "
                 " origin_server_ts FROM events"
@@ -255,12 +257,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         return result
 
-    async def _background_reindex_gin_search(self, progress, batch_size):
+    async def _background_reindex_gin_search(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """This handles old synapses which used GIST indexes, if any;
         converting them back to be GIN as per the actual schema.
         """
 
-        def create_index(conn):
+        def create_index(conn: LoggingDatabaseConnection) -> None:
             conn.rollback()
 
             # we have to set autocommit, because postgres refuses to
@@ -299,7 +303,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
         )
         return 1
 
-    async def _background_reindex_search_order(self, progress, batch_size):
+    async def _background_reindex_search_order(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
@@ -307,7 +313,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         if not have_added_index:
 
-            def create_index(conn):
+            def create_index(conn: LoggingDatabaseConnection) -> None:
                 conn.rollback()
                 conn.set_session(autocommit=True)
                 c = conn.cursor()
@@ -336,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
                 pg,
             )
 
-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
             sql = (
                 "UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
                 " origin_server_ts = e.origin_server_ts"
@@ -644,7 +650,8 @@ class SearchStore(SearchBackgroundUpdateStore):
         else:
             raise Exception("Unrecognized database engine")
 
-        args.append(limit)
+        # mypy expects to append only a `str`, not an `int`
+        args.append(limit)  # type: ignore[arg-type]
 
         results = await self.db_pool.execute(
             "search_rooms", self.db_pool.cursor_to_dict, sql, *args
@@ -705,7 +712,7 @@ class SearchStore(SearchBackgroundUpdateStore):
             A set of strings.
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> Set[str]:
             highlight_words = set()
             for event in events:
                 # As a hack we simply join values of all possible keys. This is
@@ -759,11 +766,11 @@ class SearchStore(SearchBackgroundUpdateStore):
         return await self.db_pool.runInteraction("_find_highlights", f)
 
 
-def _to_postgres_options(options_dict):
+def _to_postgres_options(options_dict: JsonDict) -> str:
     return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
 
 
-def _parse_query(database_engine, search_term):
+def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
     """Takes a plain unicode string from the user and converts it into a form
     that can be passed to database.
     We use this so that we can add prefix matching, which isn't something
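
Note on the search.py hunk: `args.append(limit)  # type: ignore[arg-type]` is needed because mypy infers `args` as a list of str from the string parameters added earlier, so appending the integer limit is flagged even though the database driver accepts mixed parameter types. A sketch of the inference problem and one alternative to the ignore (hypothetical helper names, not from this diff):

    from typing import Any, List


    def build_args_with_ignore(search_term: str, limit: int) -> List[Any]:
        args = [search_term]  # mypy infers List[str] from this element
        args.append(limit)  # type: ignore[arg-type]  # mypy: "int" is not "str"
        return args


    def build_args_annotated(search_term: str, limit: int) -> List[Any]:
        args: List[Any] = [search_term]  # widening the annotation avoids the ignore
        args.append(limit)
        return args
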
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 0373af86c8..0e3a23a140 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -788,30 +788,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return None
 
     async def get_current_room_stream_token_for_room_id(
-        self, room_id: Optional[str] = None
+        self, room_id: str
     ) -> RoomStreamToken:
-        """Returns the current position of the rooms stream.
-
-        By default, it returns a live token with the current global stream
-        token. Specifying a `room_id` causes it to return a historic token with
-        the room specific topological token.
-        """
+        """Returns the current position of the rooms stream (historic token)."""
         stream_ordering = self.get_room_max_stream_ordering()
-        if room_id is None:
-            return RoomStreamToken(None, stream_ordering)
-        else:
-            topo = await self.db_pool.runInteraction(
-                "_get_max_topological_txn", self._get_max_topological_txn, room_id
-            )
-            return RoomStreamToken(topo, stream_ordering)
+        topo = await self.db_pool.runInteraction(
+            "_get_max_topological_txn", self._get_max_topological_txn, room_id
+        )
+        return RoomStreamToken(topo, stream_ordering)
 
     def get_stream_id_for_event_txn(
         self,
         txn: LoggingTransaction,
         event_id: str,
-        allow_none=False,
-    ) -> int:
-        return self.db_pool.simple_select_one_onecol_txn(
+        allow_none: bool = False,
+    ) -> Optional[int]:
+        # Type ignore: we pass keyvalues a Dict[str, str]; the function wants
+        # Dict[str, Any]. I think mypy is unhappy because Dict is invariant?
+        return self.db_pool.simple_select_one_onecol_txn(  # type: ignore[call-overload]
             txn=txn,
             table="events",
             keyvalues={"event_id": event_id},
@@ -873,7 +867,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
         rows = txn.fetchall()
-        return rows[0][0] if rows else 0
+        # An aggregate function like MAX() will always return one row per group
+        # so we can safely rely on the lookup here. For example, when we look
+        # up a `room_id` which does not exist, `rows` will look like
+        # `[(None,)]`
+        return rows[0][0] if rows[0][0] is not None else 0
 
     @staticmethod
     def _set_before_and_after(
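
Note on the stream.py hunk above: the stricter `rows[0][0] is not None` check matters because an aggregate like MAX() over an empty match still returns a single row whose value is NULL, so the old `if rows` guard was always true and could hand None back to callers. A small sqlite3 demonstration of the shape described in the new comment:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE events (room_id TEXT, topological_ordering INTEGER)")

    cur = conn.execute(
        "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
        ("!missing:example.org",),
    )
    rows = cur.fetchall()
    assert rows == [(None,)]  # one row, NULL value, for a room that does not exist
    max_topo = rows[0][0] if rows[0][0] is not None else 0
    assert max_topo == 0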