Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--  synapse/storage/databases/main/__init__.py          |   8
-rw-r--r--  synapse/storage/databases/main/appservice.py        |  47
-rw-r--r--  synapse/storage/databases/main/cache.py             |   8
-rw-r--r--  synapse/storage/databases/main/e2e_room_keys.py     |   4
-rw-r--r--  synapse/storage/databases/main/event_federation.py  | 187
-rw-r--r--  synapse/storage/databases/main/events.py            | 236
-rw-r--r--  synapse/storage/databases/main/events_worker.py     |  35
-rw-r--r--  synapse/storage/databases/main/metrics.py           |  74
-rw-r--r--  synapse/storage/databases/main/purge_events.py      |   4
-rw-r--r--  synapse/storage/databases/main/push_rule.py         | 284
-rw-r--r--  synapse/storage/databases/main/relations.py         |   6
-rw-r--r--  synapse/storage/databases/main/room.py              |  45
-rw-r--r--  synapse/storage/databases/main/roommember.py        | 129
-rw-r--r--  synapse/storage/databases/main/search.py             |  33
-rw-r--r--  synapse/storage/databases/main/stream.py             |  34
15 files changed, 668 insertions(+), 466 deletions(-)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 5895b89202..d545a1c002 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -26,11 +26,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
 from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import (
-    IdGenerator,
-    MultiWriterIdGenerator,
-    StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -155,8 +151,6 @@ class DataStore(
             ],
         )
 
-        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
-        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
         self._group_updates_id_gen = StreamIdGenerator(
             db_conn, "local_group_updates", "stream_id"
         )
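
For context on the removal above: the dropped push-rule generators follow the simple "read the table's current maximum, then hand out increasing integers under a lock" pattern. A minimal sketch of that pattern, for illustration only (the class name and details here are not Synapse's IdGenerator, and it assumes a single writer process):

    import threading
    from sqlite3 import Connection

    class SimpleIdGenerator:
        """Hands out monotonically increasing IDs, seeded from a table's current maximum."""

        def __init__(self, db_conn: Connection, table: str, column: str) -> None:
            cur = db_conn.cursor()
            cur.execute(f"SELECT COALESCE(MAX({column}), 0) FROM {table}")
            (current,) = cur.fetchone()
            cur.close()
            self._lock = threading.Lock()
            self._current = current

        def get_next(self) -> int:
            # Only safe with a single writer; multi-writer deployments need the
            # stream-based generators that remain imported above.
            with self._lock:
                self._current += 1
                return self._current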
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 945707b0ec..e284454b66 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -203,19 +203,29 @@ class ApplicationServiceTransactionWorkerStore(
         """Get the application service state.
 
         Args:
-            service: The service whose state to set.
+            service: The service whose state to get.
         Returns:
-            An ApplicationServiceState or none.
+            An ApplicationServiceState, or None if we have yet to attempt any
+            transactions to the AS.
         """
-        result = await self.db_pool.simple_select_one(
+        # if we have created transactions for this AS but not yet attempted to send
+        # them, we will have a row in the table with state=NULL (recording the stream
+        # positions we have processed up to).
+        #
+        # On the other hand, if we have yet to create any transactions for this AS at
+        # all, then there will be no row for the AS.
+        #
+        # In either case, we return None to indicate "we don't yet know the state of
+        # this AS".
+        result = await self.db_pool.simple_select_one_onecol(
             "application_services_state",
             {"as_id": service.id},
-            ["state"],
+            retcol="state",
             allow_none=True,
             desc="get_appservice_state",
         )
         if result:
-            return ApplicationServiceState(result.get("state"))
+            return ApplicationServiceState(result)
         return None
 
     async def set_appservice_state(
@@ -296,14 +306,6 @@ class ApplicationServiceTransactionWorkerStore(
         """
 
         def _complete_appservice_txn(txn: LoggingTransaction) -> None:
-            # Set current txn_id for AS to 'txn_id'
-            self.db_pool.simple_upsert_txn(
-                txn,
-                "application_services_state",
-                {"as_id": service.id},
-                {"last_txn": txn_id},
-            )
-
             # Delete txn
             self.db_pool.simple_delete_txn(
                 txn,
@@ -452,16 +454,15 @@ class ApplicationServiceTransactionWorkerStore(
                 % (stream_type,)
             )
 
-        def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None:
-            stream_id_type = "%s_stream_id" % stream_type
-            txn.execute(
-                "UPDATE application_services_state SET %s = ? WHERE as_id=?"
-                % stream_id_type,
-                (pos, service.id),
-            )
-
-        await self.db_pool.runInteraction(
-            "set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn
+        # this may be the first time that we're recording any state for this AS, so
+        # we don't yet know if a row for it exists; hence we have to upsert here.
+        await self.db_pool.simple_upsert(
+            table="application_services_state",
+            keyvalues={"as_id": service.id},
+            values={f"{stream_type}_stream_id": pos},
+            # no need to lock when emulating upsert: as_id is a unique key
+            lock=False,
+            desc="set_appservice_stream_type_pos",
         )
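
For readers unfamiliar with simple_upsert: the replacement above relies on insert-or-update semantics keyed on as_id, so the first write for an AS creates the row and later writes update it in place. A self-contained sketch of that behaviour (illustrative only: it uses sqlite3 directly and assumes a read_receipt_stream_id column, rather than going through Synapse's database layer):

    import sqlite3  # ON CONFLICT ... DO UPDATE needs SQLite >= 3.24

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE application_services_state"
        " (as_id TEXT PRIMARY KEY, read_receipt_stream_id BIGINT)"
    )

    def set_stream_pos(as_id: str, pos: int) -> None:
        conn.execute(
            """
            INSERT INTO application_services_state (as_id, read_receipt_stream_id)
            VALUES (?, ?)
            ON CONFLICT (as_id) DO UPDATE
                SET read_receipt_stream_id = excluded.read_receipt_stream_id
            """,
            (as_id, pos),
        )

    set_stream_pos("as_irc", 10)  # first write inserts the row
    set_stream_pos("as_irc", 25)  # later writes update it in place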
 
 
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index dd4e83a2ad..1653a6a9b6 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -57,6 +57,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         self._instance_name = hs.get_instance_name()
 
+        self.db_pool.updates.register_background_index_update(
+            update_name="cache_invalidation_index_by_instance",
+            index_name="cache_invalidation_stream_by_instance_instance_index",
+            table="cache_invalidation_stream_by_instance",
+            columns=("instance_name", "stream_id"),
+            psql_only=True,  # The table is only on postgres DBs.
+        )
+
     async def get_all_updated_caches(
         self, instance_name: str, last_id: int, current_id: int, limit: int
     ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
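
The psql_only background update registered above exists to build a per-instance covering index for the cache invalidation stream. Roughly the following DDL, shown only as an illustration of what the background updater ends up creating on Postgres (index builds like this are typically done CONCURRENTLY so they do not block writes); the real statement is generated from the registration parameters:

    # Illustration only, not the exact statement Synapse issues.
    CREATE_INDEX_SQL = """
        CREATE INDEX CONCURRENTLY cache_invalidation_stream_by_instance_instance_index
        ON cache_invalidation_stream_by_instance (instance_name, stream_id)
    """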
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index b789a588a5..af59be6b48 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -21,7 +21,7 @@ from synapse.api.errors import StoreError
 from synapse.logging.opentracing import log_kv, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import LoggingTransaction
-from synapse.types import JsonDict, JsonSerializable
+from synapse.types import JsonDict, JsonSerializable, StreamKeyType
 from synapse.util import json_encoder
 
 
@@ -126,7 +126,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                     "message": "Set room key",
                     "room_id": room_id,
                     "session_id": session_id,
-                    "room_key": room_key,
+                    StreamKeyType.ROOM: room_key,
                 }
             )
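
The only functional requirement of the change above is that StreamKeyType.ROOM evaluates to the same "room_key" string that was previously hard-coded, so the log_kv output is unchanged. A sketch of the kind of constant involved (the real definition lives in synapse.types and has one member per notifier stream; the others are elided here):

    class StreamKeyType:
        # Sketch only: using the named constant avoids scattering the raw
        # "room_key" string across call sites.
        ROOM = "room_key"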
 
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4710224708..dcfe8caf47 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -14,7 +14,17 @@
 import itertools
 import logging
 from queue import Empty, PriorityQueue
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    cast,
+)
 
 import attr
 from prometheus_client import Counter, Gauge
@@ -33,7 +43,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
@@ -135,7 +145,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # Check if we have indexed the room so we can use the chain cover
         # algorithm.
-        room = await self.get_room(room_id)
+        room = await self.get_room(room_id)  # type: ignore[attr-defined]
         if room["has_auth_chain_index"]:
             try:
                 return await self.db_pool.runInteraction(
@@ -158,7 +168,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     def _get_auth_chain_ids_using_cover_index_txn(
-        self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        event_ids: Collection[str],
+        include_given: bool,
     ) -> Set[str]:
         """Calculates the auth chain IDs using the chain index."""
 
@@ -215,9 +229,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         chains: Dict[int, int] = {}
 
         # Add all linked chains reachable from initial set of chains.
-        for batch in batch_iter(event_chains, 1000):
+        for batch2 in batch_iter(event_chains, 1000):
             clause, args = make_in_list_sql_clause(
-                txn.database_engine, "origin_chain_id", batch
+                txn.database_engine, "origin_chain_id", batch2
             )
             txn.execute(sql % (clause,), args)
 
@@ -297,7 +311,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         front = set(event_ids)
         while front:
-            new_front = set()
+            new_front: Set[str] = set()
             for chunk in batch_iter(front, 100):
                 # Pull the auth events either from the cache or DB.
                 to_fetch: List[str] = []  # Event IDs to fetch from DB
@@ -316,7 +330,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
                     # Note we need to batch up the results by event ID before
                     # adding to the cache.
-                    to_cache = {}
+                    to_cache: Dict[str, List[Tuple[str, int]]] = {}
                     for event_id, auth_event_id, auth_event_depth in txn:
                         to_cache.setdefault(event_id, []).append(
                             (auth_event_id, auth_event_depth)
@@ -349,7 +363,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # Check if we have indexed the room so we can use the chain cover
         # algorithm.
-        room = await self.get_room(room_id)
+        room = await self.get_room(room_id)  # type: ignore[attr-defined]
         if room["has_auth_chain_index"]:
             try:
                 return await self.db_pool.runInteraction(
@@ -370,7 +384,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     def _get_auth_chain_difference_using_cover_index_txn(
-        self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
+        self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
     ) -> Set[str]:
         """Calculates the auth chain difference using the chain index.
 
@@ -444,9 +458,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # (We need to take a copy of `seen_chains` as we want to mutate it in
         # the loop)
-        for batch in batch_iter(set(seen_chains), 1000):
+        for batch2 in batch_iter(set(seen_chains), 1000):
             clause, args = make_in_list_sql_clause(
-                txn.database_engine, "origin_chain_id", batch
+                txn.database_engine, "origin_chain_id", batch2
             )
             txn.execute(sql % (clause,), args)
 
@@ -529,7 +543,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         return result
 
     def _get_auth_chain_difference_txn(
-        self, txn, state_sets: List[Set[str]]
+        self, txn: LoggingTransaction, state_sets: List[Set[str]]
     ) -> Set[str]:
         """Calculates the auth chain difference using a breadth first search.
 
@@ -602,7 +616,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
             # I think building a temporary list with fetchall is more efficient than
             # just `search.extend(txn)`, but this is unconfirmed
-            search.extend(txn.fetchall())
+            search.extend(cast(List[Tuple[int, str]], txn.fetchall()))
 
         # sort by depth
         search.sort()
@@ -645,7 +659,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 # We parse the results and add the to the `found` set and the
                 # cache (note we need to batch up the results by event ID before
                 # adding to the cache).
-                to_cache = {}
+                to_cache: Dict[str, List[Tuple[str, int]]] = {}
                 for event_id, auth_event_id, auth_event_depth in txn:
                     to_cache.setdefault(event_id, []).append(
                         (auth_event_id, auth_event_depth)
@@ -696,7 +710,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         return {eid for eid, n in event_to_missing_sets.items() if n}
 
     async def get_oldest_event_ids_with_depth_in_room(
-        self, room_id
+        self, room_id: str
     ) -> List[Tuple[str, int]]:
         """Gets the oldest events(backwards extremities) in the room along with the
         aproximate depth.
@@ -713,7 +727,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             List of (event_id, depth) tuples
         """
 
-        def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id):
+        def get_oldest_event_ids_with_depth_in_room_txn(
+            txn: LoggingTransaction, room_id: str
+        ) -> List[Tuple[str, int]]:
             # Assemble a dictionary with event_id -> depth for the oldest events
             # we know of in the room. Backwards extremeties are the oldest
             # events we know of in the room but we only know of them because
@@ -743,7 +759,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
             txn.execute(sql, (room_id, False))
 
-            return txn.fetchall()
+            return cast(List[Tuple[str, int]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_oldest_event_ids_with_depth_in_room",
@@ -752,7 +768,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     async def get_insertion_event_backward_extremities_in_room(
-        self, room_id
+        self, room_id: str
     ) -> List[Tuple[str, int]]:
         """Get the insertion events we know about that we haven't backfilled yet.
 
@@ -768,7 +784,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             List of (event_id, depth) tuples
         """
 
-        def get_insertion_event_backward_extremities_in_room_txn(txn, room_id):
+        def get_insertion_event_backward_extremities_in_room_txn(
+            txn: LoggingTransaction, room_id: str
+        ) -> List[Tuple[str, int]]:
             sql = """
                 SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
                 /* We only want insertion events that are also marked as backwards extremities */
@@ -780,7 +798,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             """
 
             txn.execute(sql, (room_id,))
-            return txn.fetchall()
+            return cast(List[Tuple[str, int]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_insertion_event_backward_extremities_in_room",
@@ -788,7 +806,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             room_id,
         )
 
-    async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
+    async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
         """Returns the event ID and depth for the event that has the max depth from a set of event IDs
 
         Args:
@@ -817,7 +835,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
             return max_depth_event_id, current_max_depth
 
-    async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
+    async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
         """Returns the event ID and depth for the event that has the min depth from a set of event IDs
 
         Args:
@@ -865,7 +883,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
         )
 
-    def _get_prev_events_for_room_txn(self, txn, room_id: str):
+    def _get_prev_events_for_room_txn(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> List[str]:
         # we just use the 10 newest events. Older events will become
         # prev_events of future events.
 
@@ -896,7 +916,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             sorted by extremity count.
         """
 
-        def _get_rooms_with_many_extremities_txn(txn):
+        def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]:
             where_clause = "1=1"
             if room_id_filter:
                 where_clause = "room_id NOT IN (%s)" % (
@@ -937,7 +957,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             "get_min_depth", self._get_min_depth_interaction, room_id
         )
 
-    def _get_min_depth_interaction(self, txn, room_id):
+    def _get_min_depth_interaction(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> Optional[int]:
         min_depth = self.db_pool.simple_select_one_onecol_txn(
             txn,
             table="room_depth",
@@ -966,22 +988,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         """
         # We want to make the cache more effective, so we clamp to the last
         # change before the given ordering.
-        last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
+        last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)  # type: ignore[attr-defined]
 
         # We don't always have a full stream_to_exterm_id table, e.g. after
         # the upgrade that introduced it, so we make sure we never ask for a
         # stream_ordering from before a restart
-        last_change = max(self._stream_order_on_start, last_change)
+        last_change = max(self._stream_order_on_start, last_change)  # type: ignore[attr-defined]
 
         # provided the last_change is recent enough, we now clamp the requested
         # stream_ordering to it.
-        if last_change > self.stream_ordering_month_ago:
+        if last_change > self.stream_ordering_month_ago:  # type: ignore[attr-defined]
             stream_ordering = min(last_change, stream_ordering)
 
         return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
 
     @cached(max_entries=5000, num_args=2)
-    async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
+    async def _get_forward_extremeties_for_room(
+        self, room_id: str, stream_ordering: int
+    ) -> List[str]:
         """For a given room_id and stream_ordering, return the forward
         extremeties of the room at that point in "time".
 
@@ -989,7 +1013,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         stream_orderings from that point.
         """
 
-        if stream_ordering <= self.stream_ordering_month_ago:
+        if stream_ordering <= self.stream_ordering_month_ago:  # type: ignore[attr-defined]
             raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
 
         sql = """
@@ -1002,7 +1026,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 WHERE room_id = ?
         """
 
-        def get_forward_extremeties_for_room_txn(txn):
+        def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]:
             txn.execute(sql, (stream_ordering, room_id))
             return [event_id for event_id, in txn]
 
@@ -1104,8 +1128,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         ]
 
     async def get_backfill_events(
-        self, room_id: str, seed_event_id_list: list, limit: int
-    ):
+        self, room_id: str, seed_event_id_list: List[str], limit: int
+    ) -> List[EventBase]:
         """Get a list of Events for a given topic that occurred before (and
         including) the events in seed_event_id_list. Return a list of max size `limit`
 
@@ -1123,10 +1147,19 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
         events = await self.get_events_as_list(event_ids)
         return sorted(
-            events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering)
+            # type-ignore: mypy doesn't like negating the Optional[int] stream_ordering.
+            # But it's never None, because these events were previously persisted to the DB.
+            events,
+            key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering),  # type: ignore[operator]
         )
 
-    def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit):
+    def _get_backfill_events(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        seed_event_id_list: List[str],
+        limit: int,
+    ) -> Set[str]:
         """
         We want to make sure that we do a breadth-first, "depth" ordered search.
         We also handle navigating historical branches of history connected by
@@ -1139,7 +1172,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             limit,
         )
 
-        event_id_results = set()
+        event_id_results: Set[str] = set()
 
         # In a PriorityQueue, the lowest valued entries are retrieved first.
         # We're using depth as the priority in the queue and tie-break based on
@@ -1147,7 +1180,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # highest and newest-in-time message. We add events to the queue with a
         # negative depth so that we process the newest-in-time messages first
         # going backwards in time. stream_ordering follows the same pattern.
-        queue = PriorityQueue()
+        queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue()
 
         for seed_event_id in seed_event_id_list:
             event_lookup_result = self.db_pool.simple_select_one_txn(
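
To make the ordering trick described in the comments above concrete: a PriorityQueue pops its lowest-valued tuple first, so pushing negated depth and stream_ordering yields the deepest, newest-in-time events first while walking backwards through history. A standalone illustration (not Synapse code):

    from queue import PriorityQueue

    queue: "PriorityQueue[tuple]" = PriorityQueue()
    # (depth, stream_ordering, event_id) for three made-up events
    for depth, stream_ordering, event_id in [(5, 100, "$a"), (7, 120, "$b"), (7, 115, "$c")]:
        queue.put((-depth, -stream_ordering, event_id, "m.room.message"))

    while not queue.empty():
        neg_depth, neg_stream, event_id, _event_type = queue.get()
        print(event_id, -neg_depth, -neg_stream)
    # prints $b (7, 120), then $c (7, 115), then $a (5, 100)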
@@ -1253,7 +1286,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         return event_id_results
 
-    async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
+    async def get_missing_events(
+        self,
+        room_id: str,
+        earliest_events: List[str],
+        latest_events: List[str],
+        limit: int,
+    ) -> List[EventBase]:
         ids = await self.db_pool.runInteraction(
             "get_missing_events",
             self._get_missing_events,
@@ -1264,11 +1303,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
         return await self.get_events_as_list(ids)
 
-    def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
+    def _get_missing_events(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        earliest_events: List[str],
+        latest_events: List[str],
+        limit: int,
+    ) -> List[str]:
 
         seen_events = set(earliest_events)
         front = set(latest_events) - seen_events
-        event_results = []
+        event_results: List[str] = []
 
         query = (
             "SELECT prev_event_id FROM event_edges "
@@ -1311,7 +1357,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
     @wrap_as_background_process("delete_old_forward_extrem_cache")
     async def _delete_old_forward_extrem_cache(self) -> None:
-        def _delete_old_forward_extrem_cache_txn(txn):
+        def _delete_old_forward_extrem_cache_txn(txn: LoggingTransaction) -> None:
             # Delete entries older than a month, while making sure we don't delete
             # the only entries for a room.
             sql = """
@@ -1324,7 +1370,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 ) AND stream_ordering < ?
             """
             txn.execute(
-                sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
+                sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)  # type: ignore[attr-defined]
             )
 
         await self.db_pool.runInteraction(
@@ -1382,7 +1428,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         """
         if self.db_pool.engine.supports_returning:
 
-            def _remove_received_event_from_staging_txn(txn):
+            def _remove_received_event_from_staging_txn(
+                txn: LoggingTransaction,
+            ) -> Optional[int]:
                 sql = """
                     DELETE FROM federation_inbound_events_staging
                     WHERE origin = ? AND event_id = ?
@@ -1390,21 +1438,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 """
 
                 txn.execute(sql, (origin, event_id))
-                return txn.fetchone()
+                row = cast(Optional[Tuple[int]], txn.fetchone())
 
-            row = await self.db_pool.runInteraction(
+                if row is None:
+                    return None
+
+                return row[0]
+
+            return await self.db_pool.runInteraction(
                 "remove_received_event_from_staging",
                 _remove_received_event_from_staging_txn,
                 db_autocommit=True,
             )
-            if row is None:
-                return None
-
-            return row[0]
 
         else:
 
-            def _remove_received_event_from_staging_txn(txn):
+            def _remove_received_event_from_staging_txn(
+                txn: LoggingTransaction,
+            ) -> Optional[int]:
                 received_ts = self.db_pool.simple_select_one_onecol_txn(
                     txn,
                     table="federation_inbound_events_staging",
@@ -1437,7 +1488,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     ) -> Optional[Tuple[str, str]]:
         """Get the next event ID in the staging area for the given room."""
 
-        def _get_next_staged_event_id_for_room_txn(txn):
+        def _get_next_staged_event_id_for_room_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Tuple[str, str]]:
             sql = """
                 SELECT origin, event_id
                 FROM federation_inbound_events_staging
@@ -1448,7 +1501,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
             txn.execute(sql, (room_id,))
 
-            return txn.fetchone()
+            return cast(Optional[Tuple[str, str]], txn.fetchone())
 
         return await self.db_pool.runInteraction(
             "get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn
@@ -1461,7 +1514,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     ) -> Optional[Tuple[str, EventBase]]:
         """Get the next event in the staging area for the given room."""
 
-        def _get_next_staged_event_for_room_txn(txn):
+        def _get_next_staged_event_for_room_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Tuple[str, str, str]]:
             sql = """
                 SELECT event_json, internal_metadata, origin
                 FROM federation_inbound_events_staging
@@ -1471,7 +1526,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             """
             txn.execute(sql, (room_id,))
 
-            return txn.fetchone()
+            return cast(Optional[Tuple[str, str, str]], txn.fetchone())
 
         row = await self.db_pool.runInteraction(
             "get_next_staged_event_for_room", _get_next_staged_event_for_room_txn
@@ -1599,18 +1654,20 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     @wrap_as_background_process("_get_stats_for_federation_staging")
-    async def _get_stats_for_federation_staging(self):
+    async def _get_stats_for_federation_staging(self) -> None:
         """Update the prometheus metrics for the inbound federation staging area."""
 
-        def _get_stats_for_federation_staging_txn(txn):
+        def _get_stats_for_federation_staging_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[int, int]:
             txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
 
             txn.execute(
                 "SELECT min(received_ts) FROM federation_inbound_events_staging"
             )
 
-            (received_ts,) = txn.fetchone()
+            (received_ts,) = cast(Tuple[Optional[int]], txn.fetchone())
 
             # If there is nothing in the staging area default it to 0.
             age = 0
@@ -1651,19 +1708,21 @@ class EventFederationStore(EventFederationWorkerStore):
             self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
         )
 
-    async def clean_room_for_join(self, room_id):
-        return await self.db_pool.runInteraction(
+    async def clean_room_for_join(self, room_id: str) -> None:
+        await self.db_pool.runInteraction(
             "clean_room_for_join", self._clean_room_for_join_txn, room_id
         )
 
-    def _clean_room_for_join_txn(self, txn, room_id):
+    def _clean_room_for_join_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
 
         txn.execute(query, (room_id,))
         txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
 
-    async def _background_delete_non_state_event_auth(self, progress, batch_size):
-        def delete_event_auth(txn):
+    async def _background_delete_non_state_event_auth(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        def delete_event_auth(txn: LoggingTransaction) -> bool:
             target_min_stream_id = progress.get("target_min_stream_id_inclusive")
             max_stream_id = progress.get("max_stream_id_exclusive")
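
A recurring pattern in this file's changes is wrapping cursor results in typing.cast: the DB-API cursor methods are loosely typed, and cast has no runtime effect but tells mypy the expected row shape. A minimal standalone illustration (sqlite3 is used here purely for demonstration):

    import sqlite3
    from typing import List, Optional, Tuple, cast

    conn = sqlite3.connect(":memory:")
    cur = conn.execute("SELECT 'event_a', 3 UNION ALL SELECT 'event_b', 7")

    # fetchall() is typed as returning a list of Any; cast asserts the row shape.
    rows = cast(List[Tuple[str, int]], cur.fetchall())

    # fetchone() can return None when there are no rows, hence the Optional cast.
    row = cast(Optional[Tuple[str, int]], conn.execute("SELECT 'x', 1").fetchone())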
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 2c86a870cf..0df8ff5395 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -36,9 +36,8 @@ from prometheus_client import Counter
 import synapse.metrics
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.room_versions import RoomVersions
-from synapse.crypto.event_signing import compute_event_reference_hash
-from synapse.events import EventBase  # noqa: F401
-from synapse.events.snapshot import EventContext  # noqa: F401
+from synapse.events import EventBase, relation_from_event
+from synapse.events.snapshot import EventContext
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -50,7 +49,7 @@ from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.engines.postgres import PostgresEngine
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
 from synapse.util.stringutils import non_null_str_or_none
@@ -130,7 +129,6 @@ class PersistEventsStore:
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         *,
-        current_state_for_room: Dict[str, StateMap[str]],
         state_delta_for_room: Dict[str, DeltaState],
         new_forward_extremities: Dict[str, Set[str]],
         use_negative_stream_ordering: bool = False,
@@ -141,8 +139,6 @@ class PersistEventsStore:
 
         Args:
             events_and_contexts:
-            current_state_for_room: Map from room_id to the current state of
-                the room based on forward extremities
             state_delta_for_room: Map from room_id to the delta to apply to
                 room state
             new_forward_extremities: Map from room_id to set of event IDs
@@ -217,9 +213,6 @@ class PersistEventsStore:
 
                 event_counter.labels(event.type, origin_type, origin_entity).inc()
 
-            for room_id, new_state in current_state_for_room.items():
-                self.store.get_current_state_ids.prefill((room_id,), new_state)
-
             for room_id, latest_event_ids in new_forward_extremities.items():
                 self.store.get_latest_event_ids_in_room.prefill(
                     (room_id,), list(latest_event_ids)
@@ -237,7 +230,9 @@ class PersistEventsStore:
         """
         results: List[str] = []
 
-        def _get_events_which_are_prevs_txn(txn, batch):
+        def _get_events_which_are_prevs_txn(
+            txn: LoggingTransaction, batch: Collection[str]
+        ) -> None:
             sql = """
             SELECT prev_event_id, internal_metadata
             FROM event_edges
@@ -287,7 +282,9 @@ class PersistEventsStore:
         # and their prev events.
         existing_prevs = set()
 
-        def _get_prevs_before_rejected_txn(txn, batch):
+        def _get_prevs_before_rejected_txn(
+            txn: LoggingTransaction, batch: Collection[str]
+        ) -> None:
             to_recursively_check = batch
 
             while to_recursively_check:
@@ -517,7 +514,7 @@ class PersistEventsStore:
     @classmethod
     def _add_chain_cover_index(
         cls,
-        txn,
+        txn: LoggingTransaction,
         db_pool: DatabasePool,
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
@@ -811,7 +808,7 @@ class PersistEventsStore:
 
     @staticmethod
     def _allocate_chain_ids(
-        txn,
+        txn: LoggingTransaction,
         db_pool: DatabasePool,
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
@@ -945,7 +942,7 @@ class PersistEventsStore:
         self,
         txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-    ):
+    ) -> None:
         """Persist the mapping from transaction IDs to event IDs (if defined)."""
 
         to_insert = []
@@ -999,7 +996,7 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         state_delta_by_room: Dict[str, DeltaState],
         stream_id: int,
-    ):
+    ) -> None:
         for room_id, delta_state in state_delta_by_room.items():
             to_delete = delta_state.to_delete
             to_insert = delta_state.to_insert
@@ -1157,7 +1154,7 @@ class PersistEventsStore:
                 txn, room_id, members_changed
             )
 
-    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
+    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         """Update the room version in the database based off current state
         events.
 
@@ -1191,7 +1188,7 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         new_forward_extremities: Dict[str, Set[str]],
         max_stream_order: int,
-    ):
+    ) -> None:
         for room_id in new_forward_extremities.keys():
             self.db_pool.simple_delete_txn(
                 txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
@@ -1256,9 +1253,9 @@ class PersistEventsStore:
 
     def _update_room_depths_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-    ):
+    ) -> None:
         """Update min_depth for each room
 
         Args:
@@ -1387,7 +1384,7 @@ class PersistEventsStore:
             # nothing to do here
             return
 
-        def event_dict(event):
+        def event_dict(event: EventBase) -> JsonDict:
             d = event.get_dict()
             d.pop("redacted", None)
             d.pop("redacted_because", None)
@@ -1478,18 +1475,20 @@ class PersistEventsStore:
             ),
         )
 
-    def _store_rejected_events_txn(self, txn, events_and_contexts):
+    def _store_rejected_events_txn(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> List[Tuple[EventBase, EventContext]]:
         """Add rows to the 'rejections' table for received events which were
         rejected
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
+            txn: db connection
+            events_and_contexts: events we are persisting
 
         Returns:
-            list[(EventBase, EventContext)] new list, without the rejected
-                events.
+            new list, without the rejected events.
         """
         # Remove the rejected events from the list now that we've added them
         # to the events table and the events_json table.
@@ -1510,7 +1509,7 @@ class PersistEventsStore:
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         all_events_and_contexts: List[Tuple[EventBase, EventContext]],
         inhibit_local_membership_updates: bool = False,
-    ):
+    ) -> None:
         """Update all the miscellaneous tables for new events
 
         Args:
@@ -1601,15 +1600,14 @@ class PersistEventsStore:
             inhibit_local_membership_updates=inhibit_local_membership_updates,
         )
 
-        # Insert event_reference_hashes table.
-        self._store_event_reference_hashes_txn(
-            txn, [event for event, _ in events_and_contexts]
-        )
-
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
 
-    def _add_to_cache(self, txn, events_and_contexts):
+    def _add_to_cache(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> None:
         to_prefill = []
 
         rows = []
@@ -1640,7 +1638,7 @@ class PersistEventsStore:
             if not row["rejects"] and not row["redacts"]:
                 to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
-        def prefill():
+        def prefill() -> None:
             for cache_entry in to_prefill:
                 self.store._get_event_cache.set(
                     (cache_entry.event.event_id,), cache_entry
@@ -1670,19 +1668,24 @@ class PersistEventsStore:
         )
 
     def insert_labels_for_event_txn(
-        self, txn, event_id, labels, room_id, topological_ordering
-    ):
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        labels: List[str],
+        room_id: str,
+        topological_ordering: int,
+    ) -> None:
         """Store the mapping between an event's ID and its labels, with one row per
         (event_id, label) tuple.
 
         Args:
-            txn (LoggingTransaction): The transaction to execute.
-            event_id (str): The event's ID.
-            labels (list[str]): A list of text labels.
-            room_id (str): The ID of the room the event was sent to.
-            topological_ordering (int): The position of the event in the room's topology.
+            txn: The transaction to execute.
+            event_id: The event's ID.
+            labels: A list of text labels.
+            room_id: The ID of the room the event was sent to.
+            topological_ordering: The position of the event in the room's topology.
         """
-        return self.db_pool.simple_insert_many_txn(
+        self.db_pool.simple_insert_many_txn(
             txn=txn,
             table="event_labels",
             keys=("event_id", "label", "room_id", "topological_ordering"),
@@ -1691,44 +1694,32 @@ class PersistEventsStore:
             ],
         )
 
-    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+    def _insert_event_expiry_txn(
+        self, txn: LoggingTransaction, event_id: str, expiry_ts: int
+    ) -> None:
         """Save the expiry timestamp associated with a given event ID.
 
         Args:
-            txn (LoggingTransaction): The database transaction to use.
-            event_id (str): The event ID the expiry timestamp is associated with.
-            expiry_ts (int): The timestamp at which to expire (delete) the event.
+            txn: The database transaction to use.
+            event_id: The event ID the expiry timestamp is associated with.
+            expiry_ts: The timestamp at which to expire (delete) the event.
         """
-        return self.db_pool.simple_insert_txn(
+        self.db_pool.simple_insert_txn(
             txn=txn,
             table="event_expiry",
             values={"event_id": event_id, "expiry_ts": expiry_ts},
         )
 
-    def _store_event_reference_hashes_txn(self, txn, events):
-        """Store a hash for a PDU
-        Args:
-            txn (cursor):
-            events (list): list of Events.
-        """
-
-        vals = []
-        for event in events:
-            ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
-            vals.append((event.event_id, ref_alg, memoryview(ref_hash_bytes)))
-
-        self.db_pool.simple_insert_many_txn(
-            txn,
-            table="event_reference_hashes",
-            keys=("event_id", "algorithm", "hash"),
-            values=vals,
-        )
-
     def _store_room_members_txn(
-        self, txn, events, *, inhibit_local_membership_updates: bool = False
-    ):
+        self,
+        txn: LoggingTransaction,
+        events: List[EventBase],
+        *,
+        inhibit_local_membership_updates: bool = False,
+    ) -> None:
         """
         Store a room member in the database.
+
         Args:
             txn: The transaction to use.
             events: List of events to store.
@@ -1765,6 +1756,7 @@ class PersistEventsStore:
         )
 
         for event in events:
+            assert event.internal_metadata.stream_ordering is not None
             txn.call_after(
                 self.store._membership_stream_cache.entity_has_changed,
                 event.state_key,
@@ -1813,55 +1805,50 @@ class PersistEventsStore:
             txn: The current database transaction.
             event: The event which might have relations.
         """
-        relation = event.content.get("m.relates_to")
+        relation = relation_from_event(event)
         if not relation:
-            # No relations
-            return
-
-        # Relations must have a type and parent event ID.
-        rel_type = relation.get("rel_type")
-        if not isinstance(rel_type, str):
+            # No relation, nothing to do.
             return
 
-        parent_id = relation.get("event_id")
-        if not isinstance(parent_id, str):
-            return
-
-        # Annotations have a key field.
-        aggregation_key = None
-        if rel_type == RelationTypes.ANNOTATION:
-            aggregation_key = relation.get("key")
-
         self.db_pool.simple_insert_txn(
             txn,
             table="event_relations",
             values={
                 "event_id": event.event_id,
-                "relates_to_id": parent_id,
-                "relation_type": rel_type,
-                "aggregation_key": aggregation_key,
+                "relates_to_id": relation.parent_id,
+                "relation_type": relation.rel_type,
+                "aggregation_key": relation.aggregation_key,
             },
         )
 
-        txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,))
         txn.call_after(
-            self.store.get_aggregation_groups_for_event.invalidate, (parent_id,)
+            self.store.get_relations_for_event.invalidate, (relation.parent_id,)
+        )
+        txn.call_after(
+            self.store.get_aggregation_groups_for_event.invalidate,
+            (relation.parent_id,),
         )
 
-        if rel_type == RelationTypes.REPLACE:
-            txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+        if relation.rel_type == RelationTypes.REPLACE:
+            txn.call_after(
+                self.store.get_applicable_edit.invalidate, (relation.parent_id,)
+            )
 
-        if rel_type == RelationTypes.THREAD:
-            txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+        if relation.rel_type == RelationTypes.THREAD:
+            txn.call_after(
+                self.store.get_thread_summary.invalidate, (relation.parent_id,)
+            )
             # It should be safe to only invalidate the cache if the user has not
             # previously participated in the thread, but that's difficult (and
             # potentially error-prone) so it is always invalidated.
             txn.call_after(
                 self.store.get_thread_participated.invalidate,
-                (parent_id, event.sender),
+                (relation.parent_id, event.sender),
             )
 
-    def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
+    def _handle_insertion_event(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         """Handles keeping track of insertion events and edges/connections.
         Part of MSC2716.
 
@@ -1922,7 +1909,7 @@ class PersistEventsStore:
                 },
             )
 
-    def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):
+    def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None:
         """Handles inserting the batch edges/connections between the batch event
         and an insertion event. Part of MSC2716.
 
@@ -2022,25 +2009,29 @@ class PersistEventsStore:
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )
 
-    def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
         if isinstance(event.content.get("topic"), str):
             self.store_event_search_txn(
                 txn, event, "content.topic", event.content["topic"]
             )
 
-    def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
         if isinstance(event.content.get("name"), str):
             self.store_event_search_txn(
                 txn, event, "content.name", event.content["name"]
             )
 
-    def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_message_txn(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         if isinstance(event.content.get("body"), str):
             self.store_event_search_txn(
                 txn, event, "content.body", event.content["body"]
             )
 
-    def _store_retention_policy_for_room_txn(self, txn, event):
+    def _store_retention_policy_for_room_txn(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         if not event.is_state():
             logger.debug("Ignoring non-state m.room.retention event")
             return
@@ -2100,8 +2091,11 @@ class PersistEventsStore:
         )
 
     def _set_push_actions_for_event_and_users_txn(
-        self, txn, events_and_contexts, all_events_and_contexts
-    ):
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        all_events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> None:
         """Handles moving push actions from staging table to main
         event_push_actions table for all events in `events_and_contexts`.
 
@@ -2109,12 +2103,10 @@ class PersistEventsStore:
         from the push action staging area.
 
         Args:
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
-            all_events_and_contexts (list[(EventBase, EventContext)]): all
-                events that we were going to persist. This includes events
-                we've already persisted, etc, that wouldn't appear in
-                events_and_context.
+            events_and_contexts: events we are persisting
+            all_events_and_contexts: all events that we were going to persist.
+                This includes events we've already persisted, etc, that wouldn't
+                appear in events_and_context.
         """
 
         # Only non outlier events will have push actions associated with them,
@@ -2183,7 +2175,9 @@ class PersistEventsStore:
             ),
         )
 
-    def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+    def _remove_push_actions_for_event_id_txn(
+        self, txn: LoggingTransaction, room_id: str, event_id: str
+    ) -> None:
         # Sad that we have to blow away the cache for the whole room here
         txn.call_after(
             self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
@@ -2194,7 +2188,9 @@ class PersistEventsStore:
             (room_id, event_id),
         )
 
-    def _store_rejections_txn(self, txn, event_id, reason):
+    def _store_rejections_txn(
+        self, txn: LoggingTransaction, event_id: str, reason: str
+    ) -> None:
         self.db_pool.simple_insert_txn(
             txn,
             table="rejections",
@@ -2206,8 +2202,10 @@ class PersistEventsStore:
         )
 
     def _store_event_state_mappings_txn(
-        self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
-    ):
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: Collection[Tuple[EventBase, EventContext]],
+    ) -> None:
         state_groups = {}
         for event, context in events_and_contexts:
             if event.internal_metadata.is_outlier():
@@ -2264,7 +2262,9 @@ class PersistEventsStore:
                 state_group_id,
             )
 
-    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+    def _update_min_depth_for_room_txn(
+        self, txn: LoggingTransaction, room_id: str, depth: int
+    ) -> None:
         min_depth = self.store._get_min_depth_interaction(txn, room_id)
 
         if min_depth is not None and depth >= min_depth:
@@ -2277,7 +2277,9 @@ class PersistEventsStore:
             values={"min_depth": depth},
         )
 
-    def _handle_mult_prev_events(self, txn, events):
+    def _handle_mult_prev_events(
+        self, txn: LoggingTransaction, events: List[EventBase]
+    ) -> None:
         """
         For the given event, update the event edges table and forward and
         backward extremities tables.
@@ -2295,7 +2297,9 @@ class PersistEventsStore:
 
         self._update_backward_extremeties(txn, events)
 
-    def _update_backward_extremeties(self, txn, events):
+    def _update_backward_extremeties(
+        self, txn: LoggingTransaction, events: List[EventBase]
+    ) -> None:
         """Updates the event_backward_extremities tables based on the new/updated
         events being persisted.
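
On the relation changes earlier in this file: the inline parsing of m.relates_to that the diff removes is replaced by a call to relation_from_event. A standalone sketch of the equivalent validation, reconstructed from the deleted code (the names _Relation and relation_from_content are illustrative, not Synapse's API):

    from dataclasses import dataclass
    from typing import Any, Dict, Optional

    @dataclass(frozen=True)
    class _Relation:
        parent_id: str
        rel_type: str
        aggregation_key: Optional[str]

    def relation_from_content(content: Dict[str, Any]) -> Optional[_Relation]:
        relates_to = content.get("m.relates_to")
        if not isinstance(relates_to, dict):
            return None  # no relation at all
        rel_type = relates_to.get("rel_type")
        parent_id = relates_to.get("event_id")
        # Relations must have a string rel_type and parent event ID.
        if not isinstance(rel_type, str) or not isinstance(parent_id, str):
            return None
        # Annotations additionally carry an aggregation key.
        aggregation_key = None
        if rel_type == "m.annotation":
            key = relates_to.get("key")
            if isinstance(key, str):
                aggregation_key = key
        return _Relation(parent_id, rel_type, aggregation_key)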
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index a4a604a499..5b22d6b452 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -14,6 +14,7 @@
 
 import logging
 import threading
+import weakref
 from enum import Enum, auto
 from typing import (
     TYPE_CHECKING,
@@ -23,6 +24,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    MutableMapping,
     Optional,
     Set,
     Tuple,
@@ -248,6 +250,12 @@ class EventsWorkerStore(SQLBaseStore):
             str, ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = {}
 
+        # We keep track of the events we have currently loaded in memory so that
+        # we can reuse them even if they've been evicted from the cache. We only
+        # track events that don't need redacting in here (as then we don't need
+        # to track redaction status).
+        self._event_ref: MutableMapping[str, EventBase] = weakref.WeakValueDictionary()
+
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list: List[
             Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
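
The new _event_ref map above is a weakref.WeakValueDictionary: it lets recently loaded events be reused even after they fall out of the LRU cache, without the map itself keeping them alive. A standalone illustration of that behaviour (not Synapse code; prompt cleanup of the weak entry relies on CPython's reference counting):

    import weakref

    class Event:
        def __init__(self, event_id: str) -> None:
            self.event_id = event_id

    event_ref: "weakref.WeakValueDictionary[str, Event]" = weakref.WeakValueDictionary()

    cache = {"$abc": Event("$abc")}    # stands in for the LRU event cache
    event_ref["$abc"] = cache["$abc"]  # track a weak reference alongside it

    assert event_ref.get("$abc") is cache["$abc"]  # reusable while still referenced
    del cache["$abc"]                              # drop the only strong reference
    assert event_ref.get("$abc") is None           # the weak entry disappears too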
@@ -723,6 +731,8 @@ class EventsWorkerStore(SQLBaseStore):
 
     def _invalidate_get_event_cache(self, event_id: str) -> None:
         self._get_event_cache.invalidate((event_id,))
+        self._event_ref.pop(event_id, None)
+        self._current_event_fetches.pop(event_id, None)
 
     def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
@@ -738,13 +748,30 @@ class EventsWorkerStore(SQLBaseStore):
         event_map = {}
 
         for event_id in events:
+            # First check if it's in the event cache
             ret = self._get_event_cache.get(
                 (event_id,), None, update_metrics=update_metrics
             )
-            if not ret:
+            if ret:
+                event_map[event_id] = ret
                 continue
 
-            event_map[event_id] = ret
+            # Otherwise check if we still have the event in memory.
+            event = self._event_ref.get(event_id)
+            if event:
+                # Reconstruct an event cache entry
+
+                cache_entry = EventCacheEntry(
+                    event=event,
+                    # We don't cache weakrefs to redacted events, so we know
+                    # this is None.
+                    redacted_event=None,
+                )
+                event_map[event_id] = cache_entry
+
+                # We add the entry back into the cache as we want to keep
+                # recently queried events in the cache.
+                self._get_event_cache.set((event_id,), cache_entry)
 
         return event_map
 
@@ -1124,6 +1151,10 @@ class EventsWorkerStore(SQLBaseStore):
             self._get_event_cache.set((event_id,), cache_entry)
             result_map[event_id] = cache_entry
 
+            if not redacted_event:
+                # We only cache references to unredacted events.
+                self._event_ref[event_id] = original_ev
+
         return result_map
 
     async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 1480a0f048..14294a0bb8 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -14,12 +14,16 @@
 import calendar
 import logging
 import time
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, List, Tuple, cast
 
 from synapse.metrics import GaugeBucketCollector
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
@@ -71,8 +75,8 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         self._last_user_visit_update = self._get_start_of_day()
 
     @wrap_as_background_process("read_forward_extremities")
-    async def _read_forward_extremities(self):
-        def fetch(txn):
+    async def _read_forward_extremities(self) -> None:
+        def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
             txn.execute(
                 """
                 SELECT t1.c, t2.c
@@ -85,7 +89,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
                 ) t2 ON t1.room_id = t2.room_id
                 """
             )
-            return txn.fetchall()
+            return cast(List[Tuple[int, int]], txn.fetchall())
 
         res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
 
@@ -95,7 +99,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (x[0] - 1) * x[1] for x in res if x[1]
         )
 
-    async def count_daily_e2ee_messages(self):
+    async def count_daily_e2ee_messages(self) -> int:
         """
         Returns an estimate of the number of messages sent in the last day.
 
@@ -103,20 +107,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         call to this function, it will return None.
         """
 
-        def _count_messages(txn):
+        def _count_messages(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT COUNT(*) FROM events
                 WHERE type = 'm.room.encrypted'
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
 
-    async def count_daily_sent_e2ee_messages(self):
-        def _count_messages(txn):
+    async def count_daily_sent_e2ee_messages(self) -> int:
+        def _count_messages(txn: LoggingTransaction) -> int:
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.
             like_clause = "%:" + self.hs.hostname
@@ -129,29 +133,29 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             """
 
             txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction(
             "count_daily_sent_e2ee_messages", _count_messages
         )
 
-    async def count_daily_active_e2ee_rooms(self):
-        def _count(txn):
+    async def count_daily_active_e2ee_rooms(self) -> int:
+        def _count(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events
                 WHERE type = 'm.room.encrypted'
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction(
             "count_daily_active_e2ee_rooms", _count
         )
 
-    async def count_daily_messages(self):
+    async def count_daily_messages(self) -> int:
         """
         Returns an estimate of the number of messages sent in the last day.
 
@@ -159,20 +163,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         call to this function, it will return None.
         """
 
-        def _count_messages(txn):
+        def _count_messages(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT COUNT(*) FROM events
                 WHERE type = 'm.room.message'
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction("count_messages", _count_messages)
 
-    async def count_daily_sent_messages(self):
-        def _count_messages(txn):
+    async def count_daily_sent_messages(self) -> int:
+        def _count_messages(txn: LoggingTransaction) -> int:
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.
             like_clause = "%:" + self.hs.hostname
@@ -185,22 +189,22 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             """
 
             txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction(
             "count_daily_sent_messages", _count_messages
         )
 
-    async def count_daily_active_rooms(self):
-        def _count(txn):
+    async def count_daily_active_rooms(self) -> int:
+        def _count(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events
                 WHERE type = 'm.room.message'
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
@@ -226,7 +230,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_monthly_users", self._count_users, thirty_days_ago
         )
 
-    def _count_users(self, txn, time_from):
+    def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
         """
         Returns number of users seen in the past time_from period
         """
@@ -238,7 +242,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             ) u
         """
         txn.execute(sql, (time_from,))
-        (count,) = txn.fetchone()
+        # Mypy knows that fetchone() might return None if there are no rows.
+        # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
+        # returns exactly one row.
+        (count,) = cast(Tuple[int], txn.fetchone())
         return count
 
     async def count_r30_users(self) -> Dict[str, int]:
@@ -252,7 +259,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
              A mapping of counts globally as well as broken out by platform.
         """
 
-        def _count_r30_users(txn):
+        def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]:
             thirty_days_in_secs = 86400 * 30
             now = int(self._clock.time())
             thirty_days_ago_in_secs = now - thirty_days_in_secs
@@ -317,7 +324,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
             txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
 
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             results["all"] = count
 
             return results
@@ -344,7 +351,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
               - "web" (any web application -- it's not possible to distinguish Element Web here)
         """
 
-        def _count_r30v2_users(txn):
+        def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]:
             thirty_days_in_secs = 86400 * 30
             now = int(self._clock.time())
             sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
@@ -441,11 +448,8 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
                     thirty_days_in_secs * 1000,
                 ),
             )
-            row = txn.fetchone()
-            if row is None:
-                results["all"] = 0
-            else:
-                results["all"] = row[0]
+            (count,) = cast(Tuple[int], txn.fetchone())
+            results["all"] = count
 
             return results
 
@@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_r30v2_users", _count_r30v2_users
         )
 
-    def _get_start_of_day(self):
+    def _get_start_of_day(self) -> int:
         """
         Returns millisecond unixtime for start of UTC day.
         """
@@ -467,7 +471,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         Generates daily visit data for use in cohort/retention analysis
         """
 
-        def _generate_user_daily_visits(txn):
+        def _generate_user_daily_visits(txn: LoggingTransaction) -> None:
             logger.info("Calling _generate_user_daily_visits")
             today_start = self._get_start_of_day()
             a_day_in_milliseconds = 24 * 60 * 60 * 1000
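
The recurring `cast(Tuple[int], txn.fetchone())` above is purely a typing aid: DB-API cursors are annotated as possibly returning None from `fetchone()`, but a bare `SELECT COUNT(*)` with no GROUP BY always produces exactly one row. A small sketch of the same pattern, with a sqlite3 cursor standing in for the `LoggingTransaction` used here:

    import sqlite3
    from typing import Tuple, cast

    conn = sqlite3.connect(":memory:")
    txn = conn.cursor()
    txn.execute("CREATE TABLE events (type TEXT, stream_ordering INTEGER)")
    txn.execute("INSERT INTO events VALUES ('m.room.message', 1)")

    txn.execute(
        "SELECT COUNT(*) FROM events WHERE type = ? AND stream_ordering > ?",
        ("m.room.message", 0),
    )
    # fetchone() is typed Optional[...], but an un-grouped COUNT(*) always
    # returns one row, so the cast and the tuple unpacking are safe.
    (count,) = cast(Tuple[int], txn.fetchone())
    assert count == 1
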
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index bfc85b3add..c94d5f9f81 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -69,7 +69,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         #     event_forward_extremities
         #     event_json
         #     event_push_actions
-        #     event_reference_hashes
         #     event_relations
         #     event_search
         #     event_to_state_groups
@@ -220,7 +219,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "event_auth",
             "event_edges",
             "event_forward_extremities",
-            "event_reference_hashes",
             "event_relations",
             "event_search",
             "rejections",
@@ -369,7 +367,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "event_edges",
             "event_json",
             "event_push_actions_staging",
-            "event_reference_hashes",
             "event_relations",
             "event_to_state_groups",
             "event_auth_chains",
@@ -420,6 +417,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "room_account_data",
             "room_tags",
             "local_current_membership",
+            "federation_inbound_events_staging",
         ):
             logger.info("[purge] removing %s from %s", room_id, table)
             txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 4ed913e248..ad67901cc1 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,14 +14,18 @@
 # limitations under the License.
 import abc
 import logging
-from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
 
 from synapse.api.errors import StoreError
 from synapse.config.homeserver import ExperimentalConfig
 from synapse.push.baserules import list_with_base_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.pusher import PusherWorkerStore
@@ -30,9 +34,12 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
 from synapse.storage.util.id_generators import (
+    AbstractStreamIdGenerator,
     AbstractStreamIdTracker,
+    IdGenerator,
     StreamIdGenerator,
 )
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -57,7 +64,11 @@ def _is_experimental_rule_enabled(
     return True
 
 
-def _load_rules(rawrules, enabled_map, experimental_config: ExperimentalConfig):
+def _load_rules(
+    rawrules: List[JsonDict],
+    enabled_map: Dict[str, bool],
+    experimental_config: ExperimentalConfig,
+) -> List[JsonDict]:
     ruleslist = []
     for rawrule in rawrules:
         rule = dict(rawrule)
@@ -137,7 +148,7 @@ class PushRulesWorkerStore(
         )
 
     @abc.abstractmethod
-    def get_max_push_rules_stream_id(self):
+    def get_max_push_rules_stream_id(self) -> int:
         """Get the position of the push rules stream.
 
         Returns:
@@ -146,7 +157,7 @@ class PushRulesWorkerStore(
         raise NotImplementedError()
 
     @cached(max_entries=5000)
-    async def get_push_rules_for_user(self, user_id):
+    async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
         rows = await self.db_pool.simple_select_list(
             table="push_rules",
             keyvalues={"user_name": user_id},
@@ -168,7 +179,7 @@ class PushRulesWorkerStore(
         return _load_rules(rows, enabled_map, self.hs.config.experimental)
 
     @cached(max_entries=5000)
-    async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
+    async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
         results = await self.db_pool.simple_select_list(
             table="push_rules_enable",
             keyvalues={"user_name": user_id},
@@ -184,13 +195,13 @@ class PushRulesWorkerStore(
             return False
         else:
 
-            def have_push_rules_changed_txn(txn):
+            def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool:
                 sql = (
                     "SELECT COUNT(stream_id) FROM push_rules_stream"
                     " WHERE user_id = ? AND ? < stream_id"
                 )
                 txn.execute(sql, (user_id, last_id))
-                (count,) = txn.fetchone()
+                (count,) = cast(Tuple[int], txn.fetchone())
                 return bool(count)
 
             return await self.db_pool.runInteraction(
@@ -202,11 +213,13 @@ class PushRulesWorkerStore(
         list_name="user_ids",
         num_args=1,
     )
-    async def bulk_get_push_rules(self, user_ids):
+    async def bulk_get_push_rules(
+        self, user_ids: Collection[str]
+    ) -> Dict[str, List[JsonDict]]:
         if not user_ids:
             return {}
 
-        results = {user_id: [] for user_id in user_ids}
+        results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
 
         rows = await self.db_pool.simple_select_many_batch(
             table="push_rules",
@@ -230,67 +243,18 @@ class PushRulesWorkerStore(
 
         return results
 
-    async def copy_push_rule_from_room_to_room(
-        self, new_room_id: str, user_id: str, rule: dict
-    ) -> None:
-        """Copy a single push rule from one room to another for a specific user.
-
-        Args:
-            new_room_id: ID of the new room.
-            user_id : ID of user the push rule belongs to.
-            rule: A push rule.
-        """
-        # Create new rule id
-        rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
-        new_rule_id = rule_id_scope + "/" + new_room_id
-
-        # Change room id in each condition
-        for condition in rule.get("conditions", []):
-            if condition.get("key") == "room_id":
-                condition["pattern"] = new_room_id
-
-        # Add the rule for the new room
-        await self.add_push_rule(
-            user_id=user_id,
-            rule_id=new_rule_id,
-            priority_class=rule["priority_class"],
-            conditions=rule["conditions"],
-            actions=rule["actions"],
-        )
-
-    async def copy_push_rules_from_room_to_room_for_user(
-        self, old_room_id: str, new_room_id: str, user_id: str
-    ) -> None:
-        """Copy all of the push rules from one room to another for a specific
-        user.
-
-        Args:
-            old_room_id: ID of the old room.
-            new_room_id: ID of the new room.
-            user_id: ID of user to copy push rules for.
-        """
-        # Retrieve push rules for this user
-        user_push_rules = await self.get_push_rules_for_user(user_id)
-
-        # Get rules relating to the old room and copy them to the new room
-        for rule in user_push_rules:
-            conditions = rule.get("conditions", [])
-            if any(
-                (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
-                for c in conditions
-            ):
-                await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
-
     @cachedList(
         cached_method_name="get_push_rules_enabled_for_user",
         list_name="user_ids",
         num_args=1,
     )
-    async def bulk_get_push_rules_enabled(self, user_ids):
+    async def bulk_get_push_rules_enabled(
+        self, user_ids: Collection[str]
+    ) -> Dict[str, Dict[str, bool]]:
         if not user_ids:
             return {}
 
-        results = {user_id: {} for user_id in user_ids}
+        results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
 
         rows = await self.db_pool.simple_select_many_batch(
             table="push_rules_enable",
@@ -306,7 +270,7 @@ class PushRulesWorkerStore(
 
     async def get_all_push_rule_updates(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+    ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
         """Get updates for push_rules replication stream.
 
         Args:
@@ -331,7 +295,9 @@ class PushRulesWorkerStore(
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_push_rule_updates_txn(txn):
+        def get_all_push_rule_updates_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
             sql = """
                 SELECT stream_id, user_id
                 FROM push_rules_stream
@@ -340,7 +306,10 @@ class PushRulesWorkerStore(
                 LIMIT ?
             """
             txn.execute(sql, (last_id, current_id, limit))
-            updates = [(stream_id, (user_id,)) for stream_id, user_id in txn]
+            updates = cast(
+                List[Tuple[int, Tuple[str]]],
+                [(stream_id, (user_id,)) for stream_id, user_id in txn],
+            )
 
             limited = False
             upper_bound = current_id
@@ -356,15 +325,30 @@ class PushRulesWorkerStore(
 
 
 class PushRuleStore(PushRulesWorkerStore):
+    # Because we have write access, this will be a StreamIdGenerator
+    # (see PushRulesWorkerStore.__init__)
+    _push_rules_stream_id_gen: AbstractStreamIdGenerator
+
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+
     async def add_push_rule(
         self,
-        user_id,
-        rule_id,
-        priority_class,
-        conditions,
-        actions,
-        before=None,
-        after=None,
+        user_id: str,
+        rule_id: str,
+        priority_class: int,
+        conditions: List[Dict[str, str]],
+        actions: List[Union[JsonDict, str]],
+        before: Optional[str] = None,
+        after: Optional[str] = None,
     ) -> None:
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
@@ -400,17 +384,17 @@ class PushRuleStore(PushRulesWorkerStore):
 
     def _add_push_rule_relative_txn(
         self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        priority_class,
-        conditions_json,
-        actions_json,
-        before,
-        after,
-    ):
+        txn: LoggingTransaction,
+        stream_id: int,
+        event_stream_ordering: int,
+        user_id: str,
+        rule_id: str,
+        priority_class: int,
+        conditions_json: str,
+        actions_json: str,
+        before: str,
+        after: str,
+    ) -> None:
         # Lock the table since otherwise we'll have annoying races between the
         # SELECT here and the UPSERT below.
         self.database_engine.lock_table(txn, "push_rules")
@@ -470,15 +454,15 @@ class PushRuleStore(PushRulesWorkerStore):
 
     def _add_push_rule_highest_priority_txn(
         self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        priority_class,
-        conditions_json,
-        actions_json,
-    ):
+        txn: LoggingTransaction,
+        stream_id: int,
+        event_stream_ordering: int,
+        user_id: str,
+        rule_id: str,
+        priority_class: int,
+        conditions_json: str,
+        actions_json: str,
+    ) -> None:
         # Lock the table since otherwise we'll have annoying races between the
         # SELECT here and the UPSERT below.
         self.database_engine.lock_table(txn, "push_rules")
@@ -510,17 +494,17 @@ class PushRuleStore(PushRulesWorkerStore):
 
     def _upsert_push_rule_txn(
         self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        priority_class,
-        priority,
-        conditions_json,
-        actions_json,
-        update_stream=True,
-    ):
+        txn: LoggingTransaction,
+        stream_id: int,
+        event_stream_ordering: int,
+        user_id: str,
+        rule_id: str,
+        priority_class: int,
+        priority: int,
+        conditions_json: str,
+        actions_json: str,
+        update_stream: bool = True,
+    ) -> None:
         """Specialised version of simple_upsert_txn that picks a push_rule_id
         using the _push_rule_id_gen if it needs to insert the rule. It assumes
         that the "push_rules" table is locked"""
@@ -600,7 +584,11 @@ class PushRuleStore(PushRulesWorkerStore):
             rule_id: The rule_id of the rule to be deleted
         """
 
-        def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+        def delete_push_rule_txn(
+            txn: LoggingTransaction,
+            stream_id: int,
+            event_stream_ordering: int,
+        ) -> None:
             # we don't use simple_delete_one_txn because that would fail if the
             # user did not have a push_rule_enable row.
             self.db_pool.simple_delete_txn(
@@ -661,14 +649,14 @@ class PushRuleStore(PushRulesWorkerStore):
 
     def _set_push_rule_enabled_txn(
         self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        enabled,
-        is_default_rule,
-    ):
+        txn: LoggingTransaction,
+        stream_id: int,
+        event_stream_ordering: int,
+        user_id: str,
+        rule_id: str,
+        enabled: bool,
+        is_default_rule: bool,
+    ) -> None:
         new_id = self._push_rules_enable_id_gen.get_next()
 
         if not is_default_rule:
@@ -740,7 +728,11 @@ class PushRuleStore(PushRulesWorkerStore):
         """
         actions_json = json_encoder.encode(actions)
 
-        def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
+        def set_push_rule_actions_txn(
+            txn: LoggingTransaction,
+            stream_id: int,
+            event_stream_ordering: int,
+        ) -> None:
             if is_default_rule:
                 # Add a dummy rule to the rules table with the user specified
                 # actions.
@@ -794,8 +786,15 @@ class PushRuleStore(PushRulesWorkerStore):
             )
 
     def _insert_push_rules_update_txn(
-        self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
-    ):
+        self,
+        txn: LoggingTransaction,
+        stream_id: int,
+        event_stream_ordering: int,
+        user_id: str,
+        rule_id: str,
+        op: str,
+        data: Optional[JsonDict] = None,
+    ) -> None:
         values = {
             "stream_id": stream_id,
             "event_stream_ordering": event_stream_ordering,
@@ -814,5 +813,56 @@ class PushRuleStore(PushRulesWorkerStore):
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
 
-    def get_max_push_rules_stream_id(self):
+    def get_max_push_rules_stream_id(self) -> int:
         return self._push_rules_stream_id_gen.get_current_token()
+
+    async def copy_push_rule_from_room_to_room(
+        self, new_room_id: str, user_id: str, rule: dict
+    ) -> None:
+        """Copy a single push rule from one room to another for a specific user.
+
+        Args:
+            new_room_id: ID of the new room.
+            user_id: ID of the user the push rule belongs to.
+            rule: A push rule.
+        """
+        # Create new rule id
+        rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+        new_rule_id = rule_id_scope + "/" + new_room_id
+
+        # Change room id in each condition
+        for condition in rule.get("conditions", []):
+            if condition.get("key") == "room_id":
+                condition["pattern"] = new_room_id
+
+        # Add the rule for the new room
+        await self.add_push_rule(
+            user_id=user_id,
+            rule_id=new_rule_id,
+            priority_class=rule["priority_class"],
+            conditions=rule["conditions"],
+            actions=rule["actions"],
+        )
+
+    async def copy_push_rules_from_room_to_room_for_user(
+        self, old_room_id: str, new_room_id: str, user_id: str
+    ) -> None:
+        """Copy all of the push rules from one room to another for a specific
+        user.
+
+        Args:
+            old_room_id: ID of the old room.
+            new_room_id: ID of the new room.
+            user_id: ID of user to copy push rules for.
+        """
+        # Retrieve push rules for this user
+        user_push_rules = await self.get_push_rules_for_user(user_id)
+
+        # Get rules relating to the old room and copy them to the new room
+        for rule in user_push_rules:
+            conditions = rule.get("conditions", [])
+            if any(
+                (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
+                for c in conditions
+            ):
+                await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
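
The copy helpers moved into `PushRuleStore` above re-target a room-specific rule in two steps: replace the final path segment of the `rule_id` with the new room ID, and rewrite any `room_id` condition. A small sketch of just that rewriting as pure data manipulation (the rule values below are made up for illustration; the real method then persists the result via `add_push_rule`):

    from typing import Any, Dict

    def rewrite_rule_for_room(rule: Dict[str, Any], new_room_id: str) -> Dict[str, Any]:
        """Return a copy of `rule` pointed at `new_room_id`."""
        new_rule = dict(rule)

        # Keep the scope (everything before the last "/") and swap in the new room ID.
        rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
        new_rule["rule_id"] = rule_id_scope + "/" + new_room_id

        # Point any room_id condition at the new room.
        new_rule["conditions"] = [
            {**c, "pattern": new_room_id} if c.get("key") == "room_id" else dict(c)
            for c in rule.get("conditions", [])
        ]
        return new_rule

    old_rule = {
        "rule_id": "global/override/!old:example.org",
        "conditions": [
            {"kind": "event_match", "key": "room_id", "pattern": "!old:example.org"}
        ],
    }
    print(rewrite_rule_for_room(old_rule, "!new:example.org")["rule_id"])
    # global/override/!new:example.org
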
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 484976ca6b..fe8fded88b 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -34,7 +34,7 @@ from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 logger = logging.getLogger(__name__)
@@ -161,7 +161,9 @@ class RelationsWorkerStore(SQLBaseStore):
             if len(events) > limit and last_topo_id and last_stream_id:
                 next_key = RoomStreamToken(last_topo_id, last_stream_id)
                 if from_token:
-                    next_token = from_token.copy_and_replace("room_key", next_key)
+                    next_token = from_token.copy_and_replace(
+                        StreamKeyType.ROOM, next_key
+                    )
                 else:
                     next_token = StreamToken(
                         room_key=next_key,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 87e9482c60..ded15b92ef 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -45,7 +45,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import IdGenerator
-from synapse.types import JsonDict, ThirdPartyInstanceID
+from synapse.types import JsonDict, RetentionPolicy, ThirdPartyInstanceID
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.stringutils import MXC_REGEX
@@ -699,7 +699,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
 
     @cached()
-    async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]:
+    async def get_retention_policy_for_room(self, room_id: str) -> RetentionPolicy:
         """Get the retention policy for a given room.
 
         If no retention policy has been found for this room, returns a policy defined
@@ -707,12 +707,20 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         the 'max_lifetime' if no default policy has been defined in the server's
         configuration).
 
+        If support for retention policies is disabled, a policy with a 'min_lifetime' and
+        'max_lifetime' of None is returned.
+
         Args:
             room_id: The ID of the room to get the retention policy of.
 
         Returns:
             A dict containing "min_lifetime" and "max_lifetime" for this room.
         """
+        # If the room retention feature is disabled, return a policy with no minimum or
+        # maximum. This prevents incorrectly filtering out events when sending to
+        # the client.
+        if not self.config.retention.retention_enabled:
+            return RetentionPolicy()
 
         def get_retention_policy_for_room_txn(
             txn: LoggingTransaction,
@@ -736,10 +744,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         # If we don't know this room ID, ret will be None, in this case return the default
         # policy.
         if not ret:
-            return {
-                "min_lifetime": self.config.retention.retention_default_min_lifetime,
-                "max_lifetime": self.config.retention.retention_default_max_lifetime,
-            }
+            return RetentionPolicy(
+                min_lifetime=self.config.retention.retention_default_min_lifetime,
+                max_lifetime=self.config.retention.retention_default_max_lifetime,
+            )
 
         min_lifetime = ret[0]["min_lifetime"]
         max_lifetime = ret[0]["max_lifetime"]
@@ -754,10 +762,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         if max_lifetime is None:
             max_lifetime = self.config.retention.retention_default_max_lifetime
 
-        return {
-            "min_lifetime": min_lifetime,
-            "max_lifetime": max_lifetime,
-        }
+        return RetentionPolicy(
+            min_lifetime=min_lifetime,
+            max_lifetime=max_lifetime,
+        )
 
     async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
         """Retrieves all the local and remote media MXC URIs in a given room
@@ -994,7 +1002,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
     async def get_rooms_for_retention_period_in_range(
         self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
-    ) -> Dict[str, Dict[str, Optional[int]]]:
+    ) -> Dict[str, RetentionPolicy]:
         """Retrieves all of the rooms within the given retention range.
 
         Optionally includes the rooms which don't have a retention policy.
@@ -1016,7 +1024,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
         def get_rooms_for_retention_period_in_range_txn(
             txn: LoggingTransaction,
-        ) -> Dict[str, Dict[str, Optional[int]]]:
+        ) -> Dict[str, RetentionPolicy]:
             range_conditions = []
             args = []
 
@@ -1047,10 +1055,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             rooms_dict = {}
 
             for row in rows:
-                rooms_dict[row["room_id"]] = {
-                    "min_lifetime": row["min_lifetime"],
-                    "max_lifetime": row["max_lifetime"],
-                }
+                rooms_dict[row["room_id"]] = RetentionPolicy(
+                    min_lifetime=row["min_lifetime"],
+                    max_lifetime=row["max_lifetime"],
+                )
 
             if include_null:
                 # If required, do a second query that retrieves all of the rooms we know
@@ -1065,10 +1073,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 # policy in its state), add it with a null policy.
                 for row in rows:
                     if row["room_id"] not in rooms_dict:
-                        rooms_dict[row["room_id"]] = {
-                            "min_lifetime": None,
-                            "max_lifetime": None,
-                        }
+                        rooms_dict[row["room_id"]] = RetentionPolicy()
 
             return rooms_dict
 
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 48e83592e7..cc528fcf2d 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,6 +15,7 @@
 import logging
 from typing import (
     TYPE_CHECKING,
+    Callable,
     Collection,
     Dict,
     FrozenSet,
@@ -37,7 +38,12 @@ from synapse.metrics.background_process_metrics import (
     wrap_as_background_process,
 )
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import Sqlite3Engine
 from synapse.storage.roommember import (
@@ -46,7 +52,7 @@ from synapse.storage.roommember import (
     ProfileInfo,
     RoomsForUser,
 )
-from synapse.types import PersistedEventPosition, get_domain_from_id
+from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -115,7 +121,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             )
 
     @wrap_as_background_process("_count_known_servers")
-    async def _count_known_servers(self):
+    async def _count_known_servers(self) -> int:
         """
         Count the servers that this server knows about.
 
@@ -123,7 +129,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         `synapse_federation_known_servers` LaterGauge to collect.
         """
 
-        def _transact(txn):
+        def _transact(txn: LoggingTransaction) -> int:
             if isinstance(self.database_engine, Sqlite3Engine):
                 query = """
                     SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
@@ -150,7 +156,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         self._known_servers_count = max([count, 1])
         return self._known_servers_count
 
-    def _check_safe_current_state_events_membership_updated_txn(self, txn):
+    def _check_safe_current_state_events_membership_updated_txn(
+        self, txn: LoggingTransaction
+    ) -> None:
         """Checks if it is safe to assume the new current_state_events
         membership column is up to date
         """
@@ -182,7 +190,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             "get_users_in_room", self.get_users_in_room_txn, room_id
         )
 
-    def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
+    def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
         # If we can assume current_state_events.membership is up to date
         # then we can avoid a join, which is a Very Good Thing given how
         # frequently this function gets called.
@@ -222,7 +230,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             A mapping from user ID to ProfileInfo.
         """
 
-        def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]:
+        def _get_users_in_room_with_profiles(
+            txn: LoggingTransaction,
+        ) -> Dict[str, ProfileInfo]:
             sql = """
                 SELECT state_key, display_name, avatar_url FROM room_memberships as m
                 INNER JOIN current_state_events as c
@@ -250,7 +260,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             dict of membership states, pointing to a MemberSummary named tuple.
         """
 
-        def _get_room_summary_txn(txn):
+        def _get_room_summary_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, MemberSummary]:
             # first get counts.
             # We do this all in one transaction to keep the cache small.
             # FIXME: get rid of this when we have room_stats
@@ -279,7 +291,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 """
 
             txn.execute(sql, (room_id,))
-            res = {}
+            res: Dict[str, MemberSummary] = {}
             for count, membership in txn:
                 res.setdefault(membership, MemberSummary([], count))
 
@@ -400,7 +412,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     def _get_rooms_for_local_user_where_membership_is_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         user_id: str,
         membership_list: List[str],
     ) -> List[RoomsForUser]:
@@ -488,7 +500,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     def _get_rooms_for_user_with_stream_ordering_txn(
-        self, txn, user_id: str
+        self, txn: LoggingTransaction, user_id: str
     ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
         # We use `current_state_events` here and not `local_current_membership`
         # as a) this gets called with remote users and b) this only gets called
@@ -542,7 +554,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     def _get_rooms_for_users_with_stream_ordering_txn(
-        self, txn, user_ids: Collection[str]
+        self, txn: LoggingTransaction, user_ids: Collection[str]
     ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
 
         clause, args = make_in_list_sql_clause(
@@ -575,7 +587,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         txn.execute(sql, [Membership.JOIN] + args)
 
-        result = {user_id: set() for user_id in user_ids}
+        result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
+            user_id: set() for user_id in user_ids
+        }
         for user_id, room_id, instance, stream_id in txn:
             result[user_id].add(
                 GetRoomsForUserWithStreamOrdering(
@@ -595,7 +609,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         if not user_ids:
             return set()
 
-        def _get_users_server_still_shares_room_with_txn(txn):
+        def _get_users_server_still_shares_room_with_txn(
+            txn: LoggingTransaction,
+        ) -> Set[str]:
             sql = """
                 SELECT state_key FROM current_state_events
                 WHERE
@@ -619,7 +635,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     async def get_rooms_for_user(
-        self, user_id: str, on_invalidate=None
+        self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
     ) -> FrozenSet[str]:
         """Returns a set of room_ids the user is currently joined to.
 
@@ -657,7 +673,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     async def get_joined_users_from_context(
         self, event: EventBase, context: EventContext
     ) -> Dict[str, ProfileInfo]:
-        state_group = context.state_group
+        state_group: Union[object, int] = context.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -666,14 +682,16 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             state_group = object()
 
         current_state_ids = await context.get_current_state_ids()
+        assert current_state_ids is not None
+        assert state_group is not None
         return await self._get_joined_users_from_context(
             event.room_id, state_group, current_state_ids, event=event, context=context
         )
 
     async def get_joined_users_from_state(
-        self, room_id, state_entry
+        self, room_id: str, state_entry: "_StateCacheEntry"
     ) -> Dict[str, ProfileInfo]:
-        state_group = state_entry.state_group
+        state_group: Union[object, int] = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -681,6 +699,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
+        assert state_group is not None
         with Measure(self._clock, "get_joined_users_from_state"):
             return await self._get_joined_users_from_context(
                 room_id, state_group, state_entry.state, context=state_entry
@@ -689,12 +708,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
     async def _get_joined_users_from_context(
         self,
-        room_id,
-        state_group,
-        current_state_ids,
-        cache_context,
-        event=None,
-        context=None,
+        room_id: str,
+        state_group: Union[object, int],
+        current_state_ids: StateMap[str],
+        cache_context: _CacheContext,
+        event: Optional[EventBase] = None,
+        context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
     ) -> Dict[str, ProfileInfo]:
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
@@ -765,14 +784,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return users_in_room
 
     @cached(max_entries=10000)
-    def _get_joined_profile_from_event_id(self, event_id):
+    def _get_joined_profile_from_event_id(
+        self, event_id: str
+    ) -> Optional[Tuple[str, ProfileInfo]]:
         raise NotImplementedError()
 
     @cachedList(
         cached_method_name="_get_joined_profile_from_event_id",
         list_name="event_ids",
     )
-    async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
+    async def _get_joined_profiles_from_event_ids(
+        self, event_ids: Iterable[str]
+    ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
         """For given set of member event_ids check if they point to a join
         event and if so return the associated user and profile info.
 
@@ -780,8 +803,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             event_ids: The member event IDs to lookup
 
         Returns:
-            dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
-            to `user_id` and ProfileInfo (or None if not join event).
+            Map from event ID to `user_id` and ProfileInfo (or None if not join event).
         """
 
         rows = await self.db_pool.simple_select_many_batch(
@@ -847,8 +869,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return True
 
-    async def get_joined_hosts(self, room_id: str, state_entry):
-        state_group = state_entry.state_group
+    async def get_joined_hosts(
+        self, room_id: str, state_entry: "_StateCacheEntry"
+    ) -> FrozenSet[str]:
+        state_group: Union[object, int] = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -856,6 +880,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
+        assert state_group is not None
         with Measure(self._clock, "get_joined_hosts"):
             return await self._get_joined_hosts(
                 room_id, state_group, state_entry=state_entry
@@ -863,7 +888,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     @cached(num_args=2, max_entries=10000, iterable=True)
     async def _get_joined_hosts(
-        self, room_id: str, state_group: int, state_entry: "_StateCacheEntry"
+        self,
+        room_id: str,
+        state_group: Union[object, int],
+        state_entry: "_StateCacheEntry",
     ) -> FrozenSet[str]:
         # We don't use `state_group`, it's there so that we can cache based on
         # it. However, it's important that it's never None, since two
@@ -881,7 +909,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # `get_joined_hosts` is called with the "current" state group for the
         # room, and so consecutive calls will be for consecutive state groups
         # which point to the previous state group.
-        cache = await self._get_joined_hosts_cache(room_id)
+        cache = await self._get_joined_hosts_cache(room_id)  # type: ignore[misc]
 
         # If the state group in the cache matches, we already have the data we need.
         if state_entry.state_group == cache.state_group:
@@ -897,6 +925,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             elif state_entry.prev_group == cache.state_group:
                 # The cached work is for the previous state group, so we work out
                 # the delta.
+                assert state_entry.delta_ids is not None
                 for (typ, state_key), event_id in state_entry.delta_ids.items():
                     if typ != EventTypes.Member:
                         continue
@@ -942,7 +971,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         Returns False if they have since re-joined."""
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT"
                 "  COUNT(*)"
@@ -973,7 +1002,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             The forgotten rooms.
         """
 
-        def _get_forgotten_rooms_for_user_txn(txn):
+        def _get_forgotten_rooms_for_user_txn(txn: LoggingTransaction) -> Set[str]:
             # This is a slightly convoluted query that first looks up all rooms
             # that the user has forgotten in the past, then rechecks that list
             # to see if any have subsequently been updated. This is done so that
@@ -1076,7 +1105,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             clause,
         )
 
-        def _is_local_host_in_room_ignoring_users_txn(txn):
+        def _is_local_host_in_room_ignoring_users_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
             txn.execute(sql, (room_id, Membership.JOIN, *args))
 
             return bool(txn.fetchone())
@@ -1110,15 +1141,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
             where_clause="forgotten = 1",
         )
 
-    async def _background_add_membership_profile(self, progress, batch_size):
+    async def _background_add_membership_profile(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress.get(
-            "target_min_stream_id_inclusive", self._min_stream_order_on_start
+            "target_min_stream_id_inclusive", self._min_stream_order_on_start  # type: ignore[attr-defined]
         )
         max_stream_id = progress.get(
-            "max_stream_id_exclusive", self._stream_order_on_start + 1
+            "max_stream_id_exclusive", self._stream_order_on_start + 1  # type: ignore[attr-defined]
         )
 
-        def add_membership_profile_txn(txn):
+        def add_membership_profile_txn(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT stream_ordering, event_id, events.room_id, event_json.json
                 FROM events
@@ -1182,13 +1215,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
 
         return result
 
-    async def _background_current_state_membership(self, progress, batch_size):
+    async def _background_current_state_membership(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Update the new membership column on current_state_events.
 
         This works by iterating over all rooms in alphabetical order.
         """
 
-        def _background_current_state_membership_txn(txn, last_processed_room):
+        def _background_current_state_membership_txn(
+            txn: LoggingTransaction, last_processed_room: str
+        ) -> Tuple[int, bool]:
             processed = 0
             while processed < batch_size:
                 txn.execute(
@@ -1242,7 +1279,11 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
         return row_count
 
 
-class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
+class RoomMemberStore(
+    RoomMemberWorkerStore,
+    RoomMemberBackgroundUpdateStore,
+    CacheInvalidationWorkerStore,
+):
     def __init__(
         self,
         database: DatabasePool,
@@ -1254,7 +1295,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
     async def forget(self, user_id: str, room_id: str) -> None:
         """Indicate that user_id wishes to discard history for room_id."""
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> None:
             sql = (
                 "UPDATE"
                 "  room_memberships"
@@ -1288,5 +1329,5 @@ class _JoinedHostsCache:
     # equal to anything else).
     state_group: Union[object, int] = attr.Factory(object)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return sum(len(v) for v in self.hosts_to_joined_users.values())
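
Several methods above substitute a fresh `object()` when `state_group` is unset before calling a `@cached` method. Since `object()` instances only compare equal to themselves, two callers with unassigned state groups can never share a cache entry with each other or with a real group. A toy sketch of the idea, with a simple memoising decorator standing in for Synapse's `@cached`:

    import functools
    from typing import Dict, Optional, Union

    def naive_cache(func):
        """Toy stand-in for @cached: memoise on the (hashable) first argument."""
        results: Dict[object, str] = {}

        @functools.wraps(func)
        def wrapper(key: object) -> str:
            if key not in results:
                results[key] = func(key)
            return results[key]

        return wrapper

    @naive_cache
    def expensive_lookup(state_group: Union[int, object]) -> str:
        return f"joined users for {state_group!r}"

    def lookup(state_group: Optional[int]) -> str:
        # If the state group hasn't been assigned yet, use a unique sentinel so
        # two "unassigned" callers can never collide in the cache.
        key: Union[int, object] = state_group if state_group is not None else object()
        return expensive_lookup(key)

    print(lookup(42) is lookup(42))      # True: real groups hit the cache
    print(lookup(None) is lookup(None))  # False: each unassigned call recomputes
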
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 67d3bb2b4b..7aa7126b69 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -14,7 +14,7 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
 
 import attr
 
@@ -27,7 +27,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -149,7 +149,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
             self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
         )
 
-    async def _background_reindex_search(self, progress, batch_size):
+    async def _background_reindex_search(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         # we work through the events table from highest stream id to lowest
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
@@ -157,7 +159,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
 
-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT stream_ordering, event_id, room_id, type, json, "
                 " origin_server_ts FROM events"
@@ -255,12 +257,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         return result
 
-    async def _background_reindex_gin_search(self, progress, batch_size):
+    async def _background_reindex_gin_search(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """This handles old synapses which used GIST indexes, if any;
         converting them back to be GIN as per the actual schema.
         """
 
-        def create_index(conn):
+        def create_index(conn: LoggingDatabaseConnection) -> None:
             conn.rollback()
 
             # we have to set autocommit, because postgres refuses to
@@ -299,7 +303,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
         )
         return 1
 
-    async def _background_reindex_search_order(self, progress, batch_size):
+    async def _background_reindex_search_order(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
@@ -307,7 +313,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         if not have_added_index:
 
-            def create_index(conn):
+            def create_index(conn: LoggingDatabaseConnection) -> None:
                 conn.rollback()
                 conn.set_session(autocommit=True)
                 c = conn.cursor()
@@ -336,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
                 pg,
             )
 
-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
             sql = (
                 "UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
                 " origin_server_ts = e.origin_server_ts"
@@ -644,7 +650,8 @@ class SearchStore(SearchBackgroundUpdateStore):
         else:
             raise Exception("Unrecognized database engine")
 
-        args.append(limit)
+        # mypy expects to append only a `str`, not an `int`
+        args.append(limit)  # type: ignore[arg-type]
 
         results = await self.db_pool.execute(
             "search_rooms", self.db_pool.cursor_to_dict, sql, *args
@@ -705,7 +712,7 @@ class SearchStore(SearchBackgroundUpdateStore):
             A set of strings.
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> Set[str]:
             highlight_words = set()
             for event in events:
                 # As a hack we simply join values of all possible keys. This is
@@ -759,11 +766,11 @@ class SearchStore(SearchBackgroundUpdateStore):
         return await self.db_pool.runInteraction("_find_highlights", f)
 
 
-def _to_postgres_options(options_dict):
+def _to_postgres_options(options_dict: JsonDict) -> str:
     return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
 
 
-def _parse_query(database_engine, search_term):
+def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
     """Takes a plain unicode string from the user and converts it into a form
     that can be passed to the database.
     We use this so that we can add prefix matching, which isn't something
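
The newly typed `_parse_query` above builds a per-engine query string; on PostgreSQL, prefix matching comes from tsquery's `word:*` syntax. A rough illustration of that kind of transformation (not the exact Synapse implementation, which also handles SQLite and quoting separately):

    import re

    def to_prefix_tsquery(search_term: str) -> str:
        """Turn "hello wor" into "hello:* & wor:*" for PostgreSQL's to_tsquery().

        Purely illustrative of the prefix-matching idea mentioned in the
        docstring above.
        """
        words = re.findall(r"[\w\-]+", search_term)
        return " & ".join(f"{word}:*" for word in words)

    print(to_prefix_tsquery("hello wor"))  # hello:* & wor:*
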
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 0373af86c8..0e3a23a140 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -788,30 +788,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return None
 
     async def get_current_room_stream_token_for_room_id(
-        self, room_id: Optional[str] = None
+        self, room_id: str
     ) -> RoomStreamToken:
-        """Returns the current position of the rooms stream.
-
-        By default, it returns a live token with the current global stream
-        token. Specifying a `room_id` causes it to return a historic token with
-        the room specific topological token.
-        """
+        """Returns the current position of the rooms stream (historic token)."""
         stream_ordering = self.get_room_max_stream_ordering()
-        if room_id is None:
-            return RoomStreamToken(None, stream_ordering)
-        else:
-            topo = await self.db_pool.runInteraction(
-                "_get_max_topological_txn", self._get_max_topological_txn, room_id
-            )
-            return RoomStreamToken(topo, stream_ordering)
+        topo = await self.db_pool.runInteraction(
+            "_get_max_topological_txn", self._get_max_topological_txn, room_id
+        )
+        return RoomStreamToken(topo, stream_ordering)
 
     def get_stream_id_for_event_txn(
         self,
         txn: LoggingTransaction,
         event_id: str,
-        allow_none=False,
-    ) -> int:
-        return self.db_pool.simple_select_one_onecol_txn(
+        allow_none: bool = False,
+    ) -> Optional[int]:
+        # Type ignore: we pass keyvalues a Dict[str, str]; the function wants
+        # Dict[str, Any]. I think mypy is unhappy because Dict is invariant?
+        return self.db_pool.simple_select_one_onecol_txn(  # type: ignore[call-overload]
             txn=txn,
             table="events",
             keyvalues={"event_id": event_id},
@@ -873,7 +867,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
         rows = txn.fetchall()
-        return rows[0][0] if rows else 0
+        # An aggregate function like MAX() will always return one row per group
+        # so we can safely rely on the lookup here. For example, when we
+        # look up a `room_id` which does not exist, `rows` will look like
+        # `[(None,)]`
+        return rows[0][0] if rows[0][0] is not None else 0
 
     @staticmethod
     def _set_before_and_after(
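
The final hunk above matters because an aggregate query with no matching rows still returns a row: `SELECT MAX(...)` yields `[(None,)]` rather than `[]`, so the old `rows[0][0] if rows else 0` fallback could return None instead of 0. A tiny sqlite3 demonstration of both behaviours:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    txn = conn.cursor()
    txn.execute("CREATE TABLE events (room_id TEXT, topological_ordering INTEGER)")

    txn.execute(
        "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
        ("!unknown:example.org",),
    )
    rows = txn.fetchall()
    print(rows)  # [(None,)] -- one row containing NULL, even though nothing matched

    print(rows[0][0] if rows else 0)                    # None (the old, buggy fallback)
    print(rows[0][0] if rows[0][0] is not None else 0)  # 0    (the fixed fallback)
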