summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/appservice.py6
-rw-r--r--synapse/storage/databases/main/devices.py50
-rw-r--r--synapse/storage/databases/main/event_federation.py4
-rw-r--r--synapse/storage/databases/main/event_push_actions.py19
-rw-r--r--synapse/storage/databases/main/events.py107
-rw-r--r--synapse/storage/databases/main/events_worker.py581
-rw-r--r--synapse/storage/databases/main/purge_events.py2
-rw-r--r--synapse/storage/databases/main/push_rule.py11
-rw-r--r--synapse/storage/databases/main/registration.py28
-rw-r--r--synapse/storage/databases/main/roommember.py4
-rw-r--r--synapse/storage/databases/main/stream.py15
-rw-r--r--synapse/storage/databases/main/transactions.py70
12 files changed, 717 insertions, 180 deletions
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index baec35ee27..4a883dc166 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -143,7 +143,7 @@ class ApplicationServiceTransactionWorkerStore(
             A list of ApplicationServices, which may be empty.
         """
         results = await self.db_pool.simple_select_list(
-            "application_services_state", {"state": state}, ["as_id"]
+            "application_services_state", {"state": state.value}, ["as_id"]
         )
         # NB: This assumes this class is linked with ApplicationServiceStore
         as_list = self.get_app_services()
@@ -173,7 +173,7 @@ class ApplicationServiceTransactionWorkerStore(
             desc="get_appservice_state",
         )
         if result:
-            return result.get("state")
+            return ApplicationServiceState(result.get("state"))
         return None
 
     async def set_appservice_state(
@@ -186,7 +186,7 @@ class ApplicationServiceTransactionWorkerStore(
             state: The connectivity state to apply.
         """
         await self.db_pool.simple_upsert(
-            "application_services_state", {"as_id": service.id}, {"state": state}
+            "application_services_state", {"as_id": service.id}, {"state": state.value}
         )
 
     async def create_appservice_txn(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9ccc66e589..d5a4a661cd 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return {d["device_id"]: d for d in devices}
 
+    async def get_devices_by_auth_provider_session_id(
+        self, auth_provider_id: str, auth_provider_session_id: str
+    ) -> List[Dict[str, Any]]:
+        """Retrieve the list of devices associated with a SSO IdP session ID.
+
+        Args:
+            auth_provider_id: The SSO IdP ID as defined in the server config
+            auth_provider_session_id: The session ID within the IdP
+        Returns:
+            A list of dicts containing the device_id and the user_id of each device
+        """
+        return await self.db_pool.simple_select_list(
+            table="device_auth_providers",
+            keyvalues={
+                "auth_provider_id": auth_provider_id,
+                "auth_provider_session_id": auth_provider_session_id,
+            },
+            retcols=("user_id", "device_id"),
+            desc="get_devices_by_auth_provider_session_id",
+        )
+
     @trace
     async def get_device_updates_by_remote(
         self, destination: str, from_stream_id: int, limit: int
@@ -1070,7 +1091,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     async def store_device(
-        self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
+        self,
+        user_id: str,
+        device_id: str,
+        initial_device_display_name: Optional[str],
+        auth_provider_id: Optional[str] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> bool:
         """Ensure the given device is known; add it to the store if not
 
@@ -1079,6 +1105,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             device_id: id of device
             initial_device_display_name: initial displayname of the device.
                 Ignored if device exists.
+            auth_provider_id: The SSO IdP the user used, if any.
+            auth_provider_session_id: The session ID (sid) got from a OIDC login.
 
         Returns:
             Whether the device was inserted or an existing device existed with that ID.
@@ -1115,6 +1143,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 if hidden:
                     raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
 
+            if auth_provider_id and auth_provider_session_id:
+                await self.db_pool.simple_insert(
+                    "device_auth_providers",
+                    values={
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "auth_provider_id": auth_provider_id,
+                        "auth_provider_session_id": auth_provider_session_id,
+                    },
+                    desc="store_device_auth_provider",
+                )
+
             self.device_id_exists_cache.set(key, True)
             return inserted
         except StoreError:
@@ -1168,6 +1208,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 keyvalues={"user_id": user_id},
             )
 
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="device_auth_providers",
+                column="device_id",
+                values=device_ids,
+                keyvalues={"user_id": user_id},
+            )
+
         await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
         for device_id in device_ids:
             self.device_id_exists_cache.invalidate((user_id, device_id))
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ef5d1ef01e..9580a40785 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1552,9 +1552,9 @@ class EventFederationStore(EventFederationWorkerStore):
                 DELETE FROM event_auth
                 WHERE event_id IN (
                     SELECT event_id FROM events
-                    LEFT JOIN state_events USING (room_id, event_id)
+                    LEFT JOIN state_events AS se USING (room_id, event_id)
                     WHERE ? <= stream_ordering AND stream_ordering < ?
-                        AND state_key IS null
+                        AND se.state_key IS null
                 )
             """
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index d957e770dc..3efdd0c920 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -16,6 +16,7 @@ import logging
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import attr
+from typing_extensions import TypedDict
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -37,6 +38,20 @@ DEFAULT_HIGHLIGHT_ACTION = [
 ]
 
 
+class BasePushAction(TypedDict):
+    event_id: str
+    actions: List[Union[dict, str]]
+
+
+class HttpPushAction(BasePushAction):
+    room_id: str
+    stream_ordering: int
+
+
+class EmailPushAction(HttpPushAction):
+    received_ts: Optional[int]
+
+
 def _serialize_action(actions, is_highlight):
     """Custom serializer for actions. This allows us to "compress" common actions.
 
@@ -221,7 +236,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         min_stream_ordering: int,
         max_stream_ordering: int,
         limit: int = 20,
-    ) -> List[dict]:
+    ) -> List[HttpPushAction]:
         """Get a list of the most recent unread push actions for a given user,
         within the given stream ordering range. Called by the httppusher.
 
@@ -326,7 +341,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         min_stream_ordering: int,
         max_stream_ordering: int,
         limit: int = 20,
-    ) -> List[dict]:
+    ) -> List[EmailPushAction]:
         """Get a list of the most recent unread push actions for a given user,
         within the given stream ordering range. Called by the emailpusher
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 06832221ad..4e528612ea 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 import itertools
 import logging
-from collections import OrderedDict, namedtuple
+from collections import OrderedDict
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -41,9 +41,10 @@ from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.logging.utils import log_function
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.events_worker import EventCacheEntry
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder
@@ -64,9 +65,6 @@ event_counter = Counter(
 )
 
 
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-
-
 @attr.s(slots=True)
 class DeltaState:
     """Deltas to use to update the `current_state_events` table.
@@ -108,23 +106,30 @@ class PersistEventsStore:
         self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
 
-        # Ideally we'd move these ID gens here, unfortunately some other ID
-        # generators are chained off them so doing so is a bit of a PITA.
-        self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
-        self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
-
         # This should only exist on instances that are configured to write
         assert (
             hs.get_instance_name() in hs.config.worker.writers.events
         ), "Can only instantiate EventsStore on master"
 
+        # Since we have been configured to write, we ought to have id generators,
+        # rather than id trackers.
+        assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
+        assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
+
+        # Ideally we'd move these ID gens here, unfortunately some other ID
+        # generators are chained off them so doing so is a bit of a PITA.
+        self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
+        self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+
     async def _persist_events_and_state_updates(
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
+        *,
         current_state_for_room: Dict[str, StateMap[str]],
         state_delta_for_room: Dict[str, DeltaState],
         new_forward_extremeties: Dict[str, List[str]],
-        backfilled: bool = False,
+        use_negative_stream_ordering: bool = False,
+        inhibit_local_membership_updates: bool = False,
     ) -> None:
         """Persist a set of events alongside updates to the current state and
         forward extremities tables.
@@ -137,7 +142,14 @@ class PersistEventsStore:
                 room state
             new_forward_extremities: Map from room_id to list of event IDs
                 that are the new forward extremities of the room.
-            backfilled
+            use_negative_stream_ordering: Whether to start stream_ordering on
+                the negative side and decrement. This should be set as True
+                for backfilled events because backfilled events get a negative
+                stream ordering so they don't come down incremental `/sync`.
+            inhibit_local_membership_updates: Stop the local_current_membership
+                from being updated by these events. This should be set to True
+                for backfilled events because backfilled events in the past do
+                not affect the current local state.
 
         Returns:
             Resolves when the events have been persisted
@@ -159,7 +171,7 @@ class PersistEventsStore:
         #
         # Note: Multiple instances of this function cannot be in flight at
         # the same time for the same room.
-        if backfilled:
+        if use_negative_stream_ordering:
             stream_ordering_manager = self._backfill_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
@@ -176,13 +188,13 @@ class PersistEventsStore:
                 "persist_events",
                 self._persist_events_txn,
                 events_and_contexts=events_and_contexts,
-                backfilled=backfilled,
+                inhibit_local_membership_updates=inhibit_local_membership_updates,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremeties=new_forward_extremeties,
             )
             persist_event_counter.inc(len(events_and_contexts))
 
-            if not backfilled:
+            if stream < 0:
                 # backfilled events have negative stream orderings, so we don't
                 # want to set the event_persisted_position to that.
                 synapse.metrics.event_persisted_position.set(
@@ -316,8 +328,9 @@ class PersistEventsStore:
     def _persist_events_txn(
         self,
         txn: LoggingTransaction,
+        *,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-        backfilled: bool,
+        inhibit_local_membership_updates: bool = False,
         state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
         new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
     ):
@@ -330,7 +343,10 @@ class PersistEventsStore:
         Args:
             txn
             events_and_contexts: events to persist
-            backfilled: True if the events were backfilled
+            inhibit_local_membership_updates: Stop the local_current_membership
+                from being updated by these events. This should be set to True
+                for backfilled events because backfilled events in the past do
+                not affect the current local state.
             delete_existing True to purge existing table rows for the events
                 from the database. This is useful when retrying due to
                 IntegrityError.
@@ -363,9 +379,7 @@ class PersistEventsStore:
             events_and_contexts
         )
 
-        self._update_room_depths_txn(
-            txn, events_and_contexts=events_and_contexts, backfilled=backfilled
-        )
+        self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
 
         # _update_outliers_txn filters out any events which have already been
         # persisted, and returns the filtered list.
@@ -398,7 +412,7 @@ class PersistEventsStore:
             txn,
             events_and_contexts=events_and_contexts,
             all_events_and_contexts=all_events_and_contexts,
-            backfilled=backfilled,
+            inhibit_local_membership_updates=inhibit_local_membership_updates,
         )
 
         # We call this last as it assumes we've inserted the events into
@@ -561,9 +575,9 @@ class PersistEventsStore:
         # fetch their auth event info.
         while missing_auth_chains:
             sql = """
-                SELECT event_id, events.type, state_key, chain_id, sequence_number
+                SELECT event_id, events.type, se.state_key, chain_id, sequence_number
                 FROM events
-                INNER JOIN state_events USING (event_id)
+                INNER JOIN state_events AS se USING (event_id)
                 LEFT JOIN event_auth_chains USING (event_id)
                 WHERE
             """
@@ -1200,7 +1214,6 @@ class PersistEventsStore:
         self,
         txn,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-        backfilled: bool,
     ):
         """Update min_depth for each room
 
@@ -1208,13 +1221,18 @@ class PersistEventsStore:
             txn (twisted.enterprise.adbapi.Connection): db connection
             events_and_contexts (list[(EventBase, EventContext)]): events
                 we are persisting
-            backfilled (bool): True if the events were backfilled
         """
         depth_updates: Dict[str, int] = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
             txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
-            if not backfilled:
+            # Then update the `stream_ordering` position to mark the latest
+            # event as the front of the room. This should not be done for
+            # backfilled events because backfilled events have negative
+            # stream_ordering and happened in the past so we know that we don't
+            # need to update the stream_ordering tip/front for the room.
+            assert event.internal_metadata.stream_ordering is not None
+            if event.internal_metadata.stream_ordering >= 0:
                 txn.call_after(
                     self.store._events_stream_cache.entity_has_changed,
                     event.room_id,
@@ -1427,7 +1445,12 @@ class PersistEventsStore:
         return [ec for ec in events_and_contexts if ec[0] not in to_remove]
 
     def _update_metadata_tables_txn(
-        self, txn, events_and_contexts, all_events_and_contexts, backfilled
+        self,
+        txn,
+        *,
+        events_and_contexts,
+        all_events_and_contexts,
+        inhibit_local_membership_updates: bool = False,
     ):
         """Update all the miscellaneous tables for new events
 
@@ -1439,7 +1462,10 @@ class PersistEventsStore:
                 events that we were going to persist. This includes events
                 we've already persisted, etc, that wouldn't appear in
                 events_and_context.
-            backfilled (bool): True if the events were backfilled
+            inhibit_local_membership_updates: Stop the local_current_membership
+                from being updated by these events. This should be set to True
+                for backfilled events because backfilled events in the past do
+                not affect the current local state.
         """
 
         # Insert all the push actions into the event_push_actions table.
@@ -1513,7 +1539,7 @@ class PersistEventsStore:
                 for event, _ in events_and_contexts
                 if event.type == EventTypes.Member
             ],
-            backfilled=backfilled,
+            inhibit_local_membership_updates=inhibit_local_membership_updates,
         )
 
         # Insert event_reference_hashes table.
@@ -1553,11 +1579,13 @@ class PersistEventsStore:
         for row in rows:
             event = ev_map[row["event_id"]]
             if not row["rejects"] and not row["redacts"]:
-                to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
+                to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
         def prefill():
             for cache_entry in to_prefill:
-                self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
+                self.store._get_event_cache.set(
+                    (cache_entry.event.event_id,), cache_entry
+                )
 
         txn.call_after(prefill)
 
@@ -1638,8 +1666,19 @@ class PersistEventsStore:
             txn, table="event_reference_hashes", values=vals
         )
 
-    def _store_room_members_txn(self, txn, events, backfilled):
-        """Store a room member in the database."""
+    def _store_room_members_txn(
+        self, txn, events, *, inhibit_local_membership_updates: bool = False
+    ):
+        """
+        Store a room member in the database.
+        Args:
+            txn: The transaction to use.
+            events: List of events to store.
+            inhibit_local_membership_updates: Stop the local_current_membership
+                from being updated by these events. This should be set to True
+                for backfilled events because backfilled events in the past do
+                not affect the current local state.
+        """
 
         def non_null_str_or_none(val: Any) -> Optional[str]:
             return val if isinstance(val, str) and "\u0000" not in val else None
@@ -1682,7 +1721,7 @@ class PersistEventsStore:
             # band membership", like a remote invite or a rejection of a remote invite.
             if (
                 self.is_mine_id(event.state_key)
-                and not backfilled
+                and not inhibit_local_membership_updates
                 and event.internal_metadata.is_outlier()
                 and event.internal_metadata.is_out_of_band_membership()
             ):
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c6bf316d5b..c7b660ac5a 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -15,14 +15,18 @@
 import logging
 import threading
 from typing import (
+    TYPE_CHECKING,
+    Any,
     Collection,
     Container,
     Dict,
     Iterable,
     List,
+    NoReturn,
     Optional,
     Set,
     Tuple,
+    cast,
     overload,
 )
 
@@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
     EventFormatVersions,
+    RoomVersion,
     RoomVersions,
 )
 from synapse.events import EventBase, make_event_from_dict
@@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdTracker,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import unwrapFirstError
@@ -69,10 +82,13 @@ from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# These values are used in the `enqueue_event` and `_fetch_loop` methods to
 # control how we batch/bulk fetch events from the database.
 # The values are plucked out of thing air to make initial sync run faster
 # on jki.re
@@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(
 
 
 @attr.s(slots=True, auto_attribs=True)
-class _EventCacheEntry:
+class EventCacheEntry:
     event: EventBase
     redacted_event: Optional[EventBase]
 
@@ -129,7 +145,7 @@ class _EventRow:
     json: str
     internal_metadata: str
     format_version: Optional[int]
-    room_version_id: Optional[int]
+    room_version_id: Optional[str]
     rejected_reason: Optional[str]
     redactions: List[str]
     outlier: bool
@@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
     # options controlling this.
     USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
+        self._stream_id_gen: AbstractStreamIdTracker
+        self._backfill_id_gen: AbstractStreamIdTracker
         if isinstance(database.engine, PostgresEngine):
             # If we're using Postgres than we can use `MultiWriterIdGenerator`
             # regardless of whether this process writes to the streams or not.
@@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
                 5 * 60 * 1000,
             )
 
-        self._get_event_cache = LruCache(
+        self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
             cache_name="*getEvent*",
             max_size=hs.config.caches.event_cache_size,
         )
@@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
         # ID to cache entry. Note that the returned dict may not have the
         # requested event in it if the event isn't in the DB.
         self._current_event_fetches: Dict[
-            str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+            str, ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = {}
 
         self._event_fetch_lock = threading.Condition()
-        self._event_fetch_list = []
+        self._event_fetch_list: List[
+            Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
+        ] = []
         self._event_fetch_ongoing = 0
         event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
 
         # We define this sequence here so that it can be referenced from both
         # the DataStore and PersistEventStore.
-        def get_chain_id_txn(txn):
+        def get_chain_id_txn(txn: Cursor) -> int:
             txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
-            return txn.fetchone()[0]
+            return cast(Tuple[int], txn.fetchone())[0]
 
         self.event_chain_id_gen = build_sequence_generator(
             db_conn,
@@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
             id_column="chain_id",
         )
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
         if stream_name == EventsStream.NAME:
             self._stream_id_gen.advance(instance_name, token)
         elif stream_name == BackfillStream.NAME:
@@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
-        get_prev_content: bool = False,
-        allow_rejected: bool = False,
-        allow_none: Literal[False] = False,
-        check_room_id: Optional[str] = None,
+        get_prev_content: bool = ...,
+        allow_rejected: bool = ...,
+        allow_none: Literal[False] = ...,
+        check_room_id: Optional[str] = ...,
     ) -> EventBase:
         ...
 
@@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore):
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
-        get_prev_content: bool = False,
-        allow_rejected: bool = False,
-        allow_none: Literal[True] = False,
-        check_room_id: Optional[str] = None,
+        get_prev_content: bool = ...,
+        allow_rejected: bool = ...,
+        allow_none: Literal[True] = ...,
+        check_room_id: Optional[str] = ...,
     ) -> Optional[EventBase]:
         ...
 
@@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_events(
         self,
-        event_ids: Iterable[str],
+        event_ids: Collection[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
@@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def _get_events_from_cache_or_db(
         self, event_ids: Iterable[str], allow_rejected: bool = False
-    ) -> Dict[str, _EventCacheEntry]:
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch a bunch of events from the cache or the database.
 
         If events are pulled from the database, they will be cached for future lookups.
@@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
         # same dict into itself N times).
         already_fetching_ids: Set[str] = set()
         already_fetching_deferreds: Set[
-            ObservableDeferred[Dict[str, _EventCacheEntry]]
+            ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = set()
 
         for event_id in missing_events_ids:
@@ -601,8 +632,8 @@ class EventsWorkerStore(SQLBaseStore):
             # function returning more events than requested, but that can happen
             # already due to `_get_events_from_db`).
             fetching_deferred: ObservableDeferred[
-                Dict[str, _EventCacheEntry]
-            ] = ObservableDeferred(defer.Deferred())
+                Dict[str, EventCacheEntry]
+            ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
             for event_id in missing_events_ids:
                 self._current_event_fetches[event_id] = fetching_deferred
 
@@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_entry_map
 
-    def _invalidate_get_event_cache(self, event_id):
+    def _invalidate_get_event_cache(self, event_id: str) -> None:
         self._get_event_cache.invalidate((event_id,))
 
     def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
-    ) -> Dict[str, _EventCacheEntry]:
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch events from the caches.
 
         May return rejected events.
@@ -736,38 +767,123 @@ class EventsWorkerStore(SQLBaseStore):
             for e in state_to_include.values()
         ]
 
-    def _do_fetch(self, conn: Connection) -> None:
+    def _maybe_start_fetch_thread(self) -> None:
+        """Starts an event fetch thread if we are not yet at the maximum number."""
+        with self._event_fetch_lock:
+            if (
+                self._event_fetch_list
+                and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
+            ):
+                self._event_fetch_ongoing += 1
+                event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+                # `_event_fetch_ongoing` is decremented in `_fetch_thread`.
+                should_start = True
+            else:
+                should_start = False
+
+        if should_start:
+            run_as_background_process("fetch_events", self._fetch_thread)
+
+    async def _fetch_thread(self) -> None:
+        """Services requests for events from `_event_fetch_list`."""
+        exc = None
+        try:
+            await self.db_pool.runWithConnection(self._fetch_loop)
+        except BaseException as e:
+            exc = e
+            raise
+        finally:
+            should_restart = False
+            event_fetches_to_fail = []
+            with self._event_fetch_lock:
+                self._event_fetch_ongoing -= 1
+                event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+
+                # There may still be work remaining in `_event_fetch_list` if we
+                # failed, or it was added in between us deciding to exit and
+                # decrementing `_event_fetch_ongoing`.
+                if self._event_fetch_list:
+                    if exc is None:
+                        # We decided to exit, but then some more work was added
+                        # before `_event_fetch_ongoing` was decremented.
+                        # If a new event fetch thread was not started, we should
+                        # restart ourselves since the remaining event fetch threads
+                        # may take a while to get around to the new work.
+                        #
+                        # Unfortunately it is not possible to tell whether a new
+                        # event fetch thread was started, so we restart
+                        # unconditionally. If we are unlucky, we will end up with
+                        # an idle fetch thread, but it will time out after
+                        # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
+                        # in any case.
+                        #
+                        # Note that multiple fetch threads may run down this path at
+                        # the same time.
+                        should_restart = True
+                    elif isinstance(exc, Exception):
+                        if self._event_fetch_ongoing == 0:
+                            # We were the last remaining fetcher and failed.
+                            # Fail any outstanding fetches since no one else will
+                            # handle them.
+                            event_fetches_to_fail = self._event_fetch_list
+                            self._event_fetch_list = []
+                        else:
+                            # We weren't the last remaining fetcher, so another
+                            # fetcher will pick up the work. This will either happen
+                            # after their existing work, however long that takes,
+                            # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
+                            # they are idle.
+                            pass
+                    else:
+                        # The exception is a `SystemExit`, `KeyboardInterrupt` or
+                        # `GeneratorExit`. Don't try to do anything clever here.
+                        pass
+
+            if should_restart:
+                # We exited cleanly but noticed more work.
+                self._maybe_start_fetch_thread()
+
+            if event_fetches_to_fail:
+                # We were the last remaining fetcher and failed.
+                # Fail any outstanding fetches since no one else will handle them.
+                assert exc is not None
+                with PreserveLoggingContext():
+                    for _, deferred in event_fetches_to_fail:
+                        deferred.errback(exc)
+
+    def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
         """Takes a database connection and waits for requests for events from
         the _event_fetch_list queue.
         """
-        try:
-            i = 0
-            while True:
-                with self._event_fetch_lock:
-                    event_list = self._event_fetch_list
-                    self._event_fetch_list = []
-
-                    if not event_list:
-                        single_threaded = self.database_engine.single_threaded
-                        if (
-                            not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
-                            or single_threaded
-                            or i > EVENT_QUEUE_ITERATIONS
-                        ):
-                            break
-                        else:
-                            self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
-                            i += 1
-                            continue
-                    i = 0
+        i = 0
+        while True:
+            with self._event_fetch_lock:
+                event_list = self._event_fetch_list
+                self._event_fetch_list = []
+
+                if not event_list:
+                    # There are no requests waiting. If we haven't yet reached the
+                    # maximum iteration limit, wait for some more requests to turn up.
+                    # Otherwise, bail out.
+                    single_threaded = self.database_engine.single_threaded
+                    if (
+                        not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+                        or single_threaded
+                        or i > EVENT_QUEUE_ITERATIONS
+                    ):
+                        return
+
+                    self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+                    i += 1
+                    continue
+                i = 0
 
-                self._fetch_event_list(conn, event_list)
-        finally:
-            self._event_fetch_ongoing -= 1
-            event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+            self._fetch_event_list(conn, event_list)
 
     def _fetch_event_list(
-        self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+        self,
+        conn: LoggingDatabaseConnection,
+        event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
     ) -> None:
         """Handle a load of requests from the _event_fetch_list queue
 
@@ -794,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore):
                 )
 
                 # We only want to resolve deferreds from the main thread
-                def fire():
+                def fire() -> None:
                     for _, d in event_list:
                         d.callback(row_dict)
 
@@ -804,18 +920,16 @@ class EventsWorkerStore(SQLBaseStore):
                 logger.exception("do_fetch")
 
                 # We only want to resolve deferreds from the main thread
-                def fire(evs, exc):
-                    for _, d in evs:
-                        if not d.called:
-                            with PreserveLoggingContext():
-                                d.errback(exc)
+                def fire_errback(exc: Exception) -> None:
+                    for _, d in event_list:
+                        d.errback(exc)
 
                 with PreserveLoggingContext():
-                    self.hs.get_reactor().callFromThread(fire, event_list, e)
+                    self.hs.get_reactor().callFromThread(fire_errback, e)
 
     async def _get_events_from_db(
-        self, event_ids: Iterable[str]
-    ) -> Dict[str, _EventCacheEntry]:
+        self, event_ids: Collection[str]
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch a bunch of events from the database.
 
         May return rejected events.
@@ -831,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore):
             map from event id to result. May return extra events which
             weren't asked for.
         """
-        fetched_events = {}
+        fetched_event_ids: Set[str] = set()
+        fetched_events: Dict[str, _EventRow] = {}
         events_to_fetch = event_ids
 
         while events_to_fetch:
             row_map = await self._enqueue_events(events_to_fetch)
 
             # we need to recursively fetch any redactions of those events
-            redaction_ids = set()
+            redaction_ids: Set[str] = set()
             for event_id in events_to_fetch:
                 row = row_map.get(event_id)
-                fetched_events[event_id] = row
+                fetched_event_ids.add(event_id)
                 if row:
+                    fetched_events[event_id] = row
                     redaction_ids.update(row.redactions)
 
-            events_to_fetch = redaction_ids.difference(fetched_events.keys())
+            events_to_fetch = redaction_ids.difference(fetched_event_ids)
             if events_to_fetch:
                 logger.debug("Also fetching redaction events %s", events_to_fetch)
 
         # build a map from event_id to EventBase
-        event_map = {}
+        event_map: Dict[str, EventBase] = {}
         for event_id, row in fetched_events.items():
-            if not row:
-                continue
             assert row.event_id == event_id
 
             rejected_reason = row.rejected_reason
@@ -881,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             room_version_id = row.room_version_id
 
+            room_version: Optional[RoomVersion]
             if not room_version_id:
                 # this should only happen for out-of-band membership events which
                 # arrived before #6983 landed. For all other events, we should have
@@ -951,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         # finally, we can decide whether each one needs redacting, and build
         # the cache entries.
-        result_map = {}
+        result_map: Dict[str, EventCacheEntry] = {}
         for event_id, original_ev in event_map.items():
             redactions = fetched_events[event_id].redactions
             redacted_event = self._maybe_redact_event_row(
                 original_ev, redactions, event_map
             )
 
-            cache_entry = _EventCacheEntry(
+            cache_entry = EventCacheEntry(
                 event=original_ev, redacted_event=redacted_event
             )
 
@@ -967,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return result_map
 
-    async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
+    async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.
@@ -980,23 +1095,12 @@ class EventsWorkerStore(SQLBaseStore):
             that weren't requested.
         """
 
-        events_d = defer.Deferred()
+        events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
         with self._event_fetch_lock:
             self._event_fetch_list.append((events, events_d))
-
             self._event_fetch_lock.notify()
 
-            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
-                self._event_fetch_ongoing += 1
-                event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
-                should_start = True
-            else:
-                should_start = False
-
-        if should_start:
-            run_as_background_process(
-                "fetch_events", self.db_pool.runWithConnection, self._do_fetch
-            )
+        self._maybe_start_fetch_thread()
 
         logger.debug("Loading %d events: %s", len(events), events)
         with PreserveLoggingContext():
@@ -1146,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore):
         # no valid redaction found for this event
         return None
 
-    async def have_events_in_timeline(self, event_ids):
+    async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.
         """
@@ -1175,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore):
             event_ids: events we are looking for
 
         Returns:
-            set[str]: The events we have already seen.
+            The set of events we have already seen.
         """
         res = await self._have_seen_events_dict(
             (room_id, event_id) for event_id in event_ids
@@ -1198,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore):
         }
         results = {x: True for x in cache_results}
 
-        def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
+        def have_seen_events_txn(
+            txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
+        ) -> None:
             # we deliberately do *not* query the database for room_id, to make the
             # query an index-only lookup on `events_event_id_key`.
             #
@@ -1224,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore):
         return results
 
     @cached(max_entries=100000, tree=True)
-    async def have_seen_event(self, room_id: str, event_id: str):
+    async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
         # this only exists for the benefit of the @cachedList descriptor on
         # _have_seen_events_dict
         raise NotImplementedError()
 
-    def _get_current_state_event_counts_txn(self, txn, room_id):
+    def _get_current_state_event_counts_txn(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> int:
         """
         See get_current_state_event_counts.
         """
@@ -1254,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore):
             room_id,
         )
 
-    async def get_room_complexity(self, room_id):
+    async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
         """
         Get a rough approximation of the complexity of the room. This is used by
         remote servers to decide whether they wish to join the room or not.
@@ -1262,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore):
         more resources.
 
         Args:
-            room_id (str)
+            room_id: The room ID to query.
 
         Returns:
-            dict[str:int] of complexity version to complexity.
+            dict[str:float] of complexity version to complexity.
         """
         state_events = await self.get_current_state_event_counts(room_id)
 
@@ -1275,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {"v1": complexity_v1}
 
-    def get_current_events_token(self):
+    def get_current_events_token(self) -> int:
         """The current maximum token that events have reached"""
         return self._stream_id_gen.get_current_token()
 
     async def get_all_new_forward_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> List[Tuple]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
         """Returns new events, for the Events replication stream
 
         Args:
@@ -1295,13 +1403,15 @@ class EventsWorkerStore(SQLBaseStore):
             EventsStreamRow.
         """
 
-        def get_all_new_forward_event_rows(txn):
+        def get_all_new_forward_event_rows(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
             sql = (
                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
                 " FROM events AS e"
                 " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN state_events AS se USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " LEFT JOIN room_memberships USING (event_id)"
                 " LEFT JOIN rejections USING (event_id)"
@@ -1311,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore):
                 " LIMIT ?"
             )
             txn.execute(sql, (last_id, current_id, instance_name, limit))
-            return txn.fetchall()
+            return cast(
+                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+            )
 
         return await self.db_pool.runInteraction(
             "get_all_new_forward_event_rows", get_all_new_forward_event_rows
@@ -1319,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_ex_outlier_stream_rows(
         self, instance_name: str, last_id: int, current_id: int
-    ) -> List[Tuple]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
         """Returns de-outliered events, for the Events replication stream
 
         Args:
@@ -1332,14 +1444,16 @@ class EventsWorkerStore(SQLBaseStore):
             EventsStreamRow.
         """
 
-        def get_ex_outlier_stream_rows_txn(txn):
+        def get_ex_outlier_stream_rows_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
             sql = (
                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
                 " FROM events AS e"
                 " INNER JOIN ex_outlier_stream AS out USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN state_events AS se USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " LEFT JOIN room_memberships USING (event_id)"
                 " LEFT JOIN rejections USING (event_id)"
@@ -1350,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore):
             )
 
             txn.execute(sql, (last_id, current_id, instance_name))
-            return txn.fetchall()
+            return cast(
+                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+            )
 
         return await self.db_pool.runInteraction(
             "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
@@ -1358,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_all_new_backfill_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, list]], int, bool]:
+    ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
         """Get updates for backfill replication stream, including all new
         backfilled events and events that have gone from being outliers to not.
 
@@ -1386,13 +1502,15 @@ class EventsWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_new_backfill_event_rows(txn):
+        def get_all_new_backfill_event_rows(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
             sql = (
                 "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
+                " se.state_key, redacts, relates_to_id"
                 " FROM events AS e"
                 " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN state_events AS se USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " WHERE ? > stream_ordering AND stream_ordering >= ?"
                 "  AND instance_name = ?"
@@ -1400,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore):
                 " LIMIT ?"
             )
             txn.execute(sql, (-last_id, -current_id, instance_name, limit))
-            new_event_updates = [(row[0], row[1:]) for row in txn]
+            new_event_updates: List[
+                Tuple[int, Tuple[str, str, str, str, str, str]]
+            ] = []
+            row: Tuple[int, str, str, str, str, str, str]
+            # Type safety: iterating over `txn` yields `Tuple`, i.e.
+            # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+            # variadic tuple to a fixed length tuple and flags it up as an error.
+            for row in txn:  # type: ignore[assignment]
+                new_event_updates.append((row[0], row[1:]))
 
             limited = False
             if len(new_event_updates) == limit:
@@ -1411,11 +1537,11 @@ class EventsWorkerStore(SQLBaseStore):
 
             sql = (
                 "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
+                " se.state_key, redacts, relates_to_id"
                 " FROM events AS e"
                 " INNER JOIN ex_outlier_stream AS out USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN state_events AS se USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " WHERE ? > event_stream_ordering"
                 " AND event_stream_ordering >= ?"
@@ -1423,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore):
                 " ORDER BY event_stream_ordering DESC"
             )
             txn.execute(sql, (-last_id, -upper_bound, instance_name))
-            new_event_updates.extend((row[0], row[1:]) for row in txn)
+            # Type safety: iterating over `txn` yields `Tuple`, i.e.
+            # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+            # variadic tuple to a fixed length tuple and flags it up as an error.
+            for row in txn:  # type: ignore[assignment]
+                new_event_updates.append((row[0], row[1:]))
 
             if len(new_event_updates) >= limit:
                 upper_bound = new_event_updates[-1][0]
@@ -1437,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_all_updated_current_state_deltas(
         self, instance_name: str, from_token: int, to_token: int, target_row_count: int
-    ) -> Tuple[List[Tuple], int, bool]:
+    ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
         """Fetch updates from current_state_delta_stream
 
         Args:
@@ -1457,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore):
                * `limited` is whether there are more updates to fetch.
         """
 
-        def get_all_updated_current_state_deltas_txn(txn):
+        def get_all_updated_current_state_deltas_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str]]:
             sql = """
                 SELECT stream_id, room_id, type, state_key, event_id
                 FROM current_state_delta_stream
@@ -1466,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore):
                 ORDER BY stream_id ASC LIMIT ?
             """
             txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
 
-        def get_deltas_for_stream_id_txn(txn, stream_id):
+        def get_deltas_for_stream_id_txn(
+            txn: LoggingTransaction, stream_id: int
+        ) -> List[Tuple[int, str, str, str, str]]:
             sql = """
                 SELECT stream_id, room_id, type, state_key, event_id
                 FROM current_state_delta_stream
                 WHERE stream_id = ?
             """
             txn.execute(sql, [stream_id])
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
 
         # we need to make sure that, for every stream id in the results, we get *all*
         # the rows with that stream id.
 
-        rows: List[Tuple] = await self.db_pool.runInteraction(
+        rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
         )
@@ -1509,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         return rows, to_token, True
 
-    async def is_event_after(self, event_id1, event_id2):
+    async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
         """Returns True if event_id1 is after event_id2 in the stream"""
         to_1, so_1 = await self.get_event_ordering(event_id1)
         to_2, so_2 = await self.get_event_ordering(event_id2)
         return (to_1, so_1) > (to_2, so_2)
 
     @cached(max_entries=5000)
-    async def get_event_ordering(self, event_id):
+    async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
         res = await self.db_pool.simple_select_one(
             table="events",
             retcols=["topological_ordering", "stream_ordering"],
@@ -1539,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore):
             None otherwise.
         """
 
-        def get_next_event_to_expire_txn(txn):
+        def get_next_event_to_expire_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Tuple[str, int]]:
             txn.execute(
                 """
                 SELECT event_id, expiry_ts FROM event_expiry
@@ -1547,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore):
                 """
             )
 
-            return txn.fetchone()
+            return cast(Optional[Tuple[str, int]], txn.fetchone())
 
         return await self.db_pool.runInteraction(
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
@@ -1611,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore):
         return mapping
 
     @wrap_as_background_process("_cleanup_old_transaction_ids")
-    async def _cleanup_old_transaction_ids(self):
+    async def _cleanup_old_transaction_ids(self) -> None:
         """Cleans out transaction id mappings older than 24hrs."""
 
-        def _cleanup_old_transaction_ids_txn(txn):
+        def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
             sql = """
                 DELETE FROM event_txn_id
                 WHERE inserted_ts < ?
@@ -1626,3 +1762,198 @@ class EventsWorkerStore(SQLBaseStore):
             "_cleanup_old_transaction_ids",
             _cleanup_old_transaction_ids_txn,
         )
+
+    async def is_event_next_to_backward_gap(self, event: EventBase) -> bool:
+        """Check if the given event is next to a backward gap of missing events.
+        <latest messages> A(False)--->B(False)--->C(True)--->  <gap, unknown events> <oldest messages>
+
+        Args:
+            room_id: room where the event lives
+            event_id: event to check
+
+        Returns:
+            Boolean indicating whether it's an extremity
+        """
+
+        def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool:
+            # If the event in question has any of its prev_events listed as a
+            # backward extremity, it's next to a gap.
+            #
+            # We can't just check the backward edges in `event_edges` because
+            # when we persist events, we will also record the prev_events as
+            # edges to the event in question regardless of whether we have those
+            # prev_events yet. We need to check whether those prev_events are
+            # backward extremities, also known as gaps, that need to be
+            # backfilled.
+            backward_extremity_query = """
+                SELECT 1 FROM event_backward_extremities
+                WHERE
+                    room_id = ?
+                    AND %s
+                LIMIT 1
+            """
+
+            # If the event in question is a backward extremity or has any of its
+            # prev_events listed as a backward extremity, it's next to a
+            # backward gap.
+            clause, args = make_in_list_sql_clause(
+                self.database_engine,
+                "event_id",
+                [event.event_id] + list(event.prev_event_ids()),
+            )
+
+            txn.execute(backward_extremity_query % (clause,), [event.room_id] + args)
+            backward_extremities = txn.fetchall()
+
+            # We consider any backward extremity as a backward gap
+            if len(backward_extremities):
+                return True
+
+            return False
+
+        return await self.db_pool.runInteraction(
+            "is_event_next_to_backward_gap_txn",
+            is_event_next_to_backward_gap_txn,
+        )
+
+    async def is_event_next_to_forward_gap(self, event: EventBase) -> bool:
+        """Check if the given event is next to a forward gap of missing events.
+        The gap in front of the latest events is not considered a gap.
+        <latest messages> A(False)--->B(False)--->C(False)--->  <gap, unknown events> <oldest messages>
+        <latest messages> A(False)--->B(False)--->  <gap, unknown events>  --->D(True)--->E(False) <oldest messages>
+
+        Args:
+            room_id: room where the event lives
+            event_id: event to check
+
+        Returns:
+            Boolean indicating whether it's an extremity
+        """
+
+        def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool:
+            # If the event in question is a forward extremity, we will just
+            # consider any potential forward gap as not a gap since it's one of
+            # the latest events in the room.
+            #
+            # `event_forward_extremities` does not include backfilled or outlier
+            # events so we can't rely on it to find forward gaps. We can only
+            # use it to determine whether a message is the latest in the room.
+            #
+            # We can't combine this query with the `forward_edge_query` below
+            # because if the event in question has no forward edges (isn't
+            # referenced by any other event's prev_events) but is in
+            # `event_forward_extremities`, we don't want to return 0 rows and
+            # say it's next to a gap.
+            forward_extremity_query = """
+                SELECT 1 FROM event_forward_extremities
+                WHERE
+                    room_id = ?
+                    AND event_id = ?
+                LIMIT 1
+            """
+
+            # Check to see whether the event in question is already referenced
+            # by another event. If we don't see any edges, we're next to a
+            # forward gap.
+            forward_edge_query = """
+                SELECT 1 FROM event_edges
+                /* Check to make sure the event referencing our event in question is not rejected */
+                LEFT JOIN rejections ON event_edges.event_id == rejections.event_id
+                WHERE
+                    event_edges.room_id = ?
+                    AND event_edges.prev_event_id = ?
+                    /* It's not a valid edge if the event referencing our event in
+                     * question is rejected.
+                     */
+                    AND rejections.event_id IS NULL
+                LIMIT 1
+            """
+
+            # We consider any forward extremity as the latest in the room and
+            # not a forward gap.
+            #
+            # To expand, even though there is technically a gap at the front of
+            # the room where the forward extremities are, we consider those the
+            # latest messages in the room so asking other homeservers for more
+            # is useless. The new latest messages will just be federated as
+            # usual.
+            txn.execute(forward_extremity_query, (event.room_id, event.event_id))
+            forward_extremities = txn.fetchall()
+            if len(forward_extremities):
+                return False
+
+            # If there are no forward edges to the event in question (another
+            # event hasn't referenced this event in their prev_events), then we
+            # assume there is a forward gap in the history.
+            txn.execute(forward_edge_query, (event.room_id, event.event_id))
+            forward_edges = txn.fetchall()
+            if not len(forward_edges):
+                return True
+
+            return False
+
+        return await self.db_pool.runInteraction(
+            "is_event_next_to_gap_txn",
+            is_event_next_to_gap_txn,
+        )
+
+    async def get_event_id_for_timestamp(
+        self, room_id: str, timestamp: int, direction: str
+    ) -> Optional[str]:
+        """Find the closest event to the given timestamp in the given direction.
+
+        Args:
+            room_id: Room to fetch the event from
+            timestamp: The point in time (inclusive) we should navigate from in
+                the given direction to find the closest event.
+            direction: ["f"|"b"] to indicate whether we should navigate forward
+                or backward from the given timestamp to find the closest event.
+
+        Returns:
+            The closest event_id otherwise None if we can't find any event in
+            the given direction.
+        """
+
+        sql_template = """
+            SELECT event_id FROM events
+            LEFT JOIN rejections USING (event_id)
+            WHERE
+                origin_server_ts %s ?
+                AND room_id = ?
+                /* Make sure event is not rejected */
+                AND rejections.event_id IS NULL
+            ORDER BY origin_server_ts %s
+            LIMIT 1;
+        """
+
+        def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
+            if direction == "b":
+                # Find closest event *before* a given timestamp. We use descending
+                # (which gives values largest to smallest) because we want the
+                # largest possible timestamp *before* the given timestamp.
+                comparison_operator = "<="
+                order = "DESC"
+            else:
+                # Find closest event *after* a given timestamp. We use ascending
+                # (which gives values smallest to largest) because we want the
+                # closest possible timestamp *after* the given timestamp.
+                comparison_operator = ">="
+                order = "ASC"
+
+            txn.execute(
+                sql_template % (comparison_operator, order), (timestamp, room_id)
+            )
+            row = txn.fetchone()
+            if row:
+                (event_id,) = row
+                return event_id
+
+            return None
+
+        if direction not in ("f", "b"):
+            raise ValueError("Unknown direction: %s" % (direction,))
+
+        return await self.db_pool.runInteraction(
+            "get_event_id_for_timestamp_txn",
+            get_event_id_for_timestamp_txn,
+        )
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 3eb30944bf..91b0576b85 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -118,7 +118,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
 
         logger.info("[purge] looking for events to delete")
 
-        should_delete_expr = "state_key IS NULL"
+        should_delete_expr = "state_events.state_key IS NULL"
         should_delete_params: Tuple[Any, ...] = ()
         if not delete_local_events:
             should_delete_expr += " AND event_id NOT LIKE ?"
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index fa782023d4..3b63267395 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,7 +28,10 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdTracker,
+    StreamIdGenerator,
+)
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -82,9 +85,9 @@ class PushRulesWorkerStore(
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
-            self._push_rules_stream_id_gen: Union[
-                StreamIdGenerator, SlavedIdTracker
-            ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
+            self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+                db_conn, "push_rules_stream", "stream_id"
+            )
         else:
             self._push_rules_stream_id_gen = SlavedIdTracker(
                 db_conn, "push_rules_stream", "stream_id"
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 0e8c168667..e1ddf06916 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -106,6 +106,15 @@ class RefreshTokenLookupResult:
     has_next_access_token_been_used: bool
     """True if the next access token was already used at least once."""
 
+    expiry_ts: Optional[int]
+    """The time at which the refresh token expires and can not be used.
+    If None, the refresh token doesn't expire."""
+
+    ultimate_session_expiry_ts: Optional[int]
+    """The time at which the session comes to an end and can no longer be
+    refreshed.
+    If None, the session can be refreshed indefinitely."""
+
 
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     def __init__(
@@ -1626,8 +1635,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                     rt.user_id,
                     rt.device_id,
                     rt.next_token_id,
-                    (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
-                    at.used has_next_access_token_been_used
+                    (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed,
+                    at.used AS has_next_access_token_been_used,
+                    rt.expiry_ts,
+                    rt.ultimate_session_expiry_ts
                 FROM refresh_tokens rt
                 LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
                 LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
@@ -1648,6 +1659,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 has_next_refresh_token_been_refreshed=row[4],
                 # This column is nullable, ensure it's a boolean
                 has_next_access_token_been_used=(row[5] or False),
+                expiry_ts=row[6],
+                ultimate_session_expiry_ts=row[7],
             )
 
         return await self.db_pool.runInteraction(
@@ -1915,6 +1928,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         user_id: str,
         token: str,
         device_id: Optional[str],
+        expiry_ts: Optional[int],
+        ultimate_session_expiry_ts: Optional[int],
     ) -> int:
         """Adds a refresh token for the given user.
 
@@ -1922,6 +1937,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             user_id: The user ID.
             token: The new access token to add.
             device_id: ID of the device to associate with the refresh token.
+            expiry_ts (milliseconds since the epoch): Time after which the
+                refresh token cannot be used.
+                If None, the refresh token never expires until it has been used.
+            ultimate_session_expiry_ts (milliseconds since the epoch):
+                Time at which the session will end and can not be extended any
+                further.
+                If None, the session can be refreshed indefinitely.
         Raises:
             StoreError if there was a problem adding this.
         Returns:
@@ -1937,6 +1959,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 "device_id": device_id,
                 "token": token,
                 "next_token_id": None,
+                "expiry_ts": expiry_ts,
+                "ultimate_session_expiry_ts": ultimate_session_expiry_ts,
             },
             desc="add_refresh_token_to_user",
         )
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 033a9831d6..6b2a8d06a6 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -476,7 +476,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
                     c.type = 'm.room.member'
-                    AND state_key = ?
+                    AND c.state_key = ?
                     AND c.membership = ?
             """
         else:
@@ -487,7 +487,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
                     c.type = 'm.room.member'
-                    AND state_key = ?
+                    AND c.state_key = ?
                     AND m.membership = ?
             """
 
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 42dc807d17..57aab55259 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -497,7 +497,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
                 oldest `limit` events.
 
         Returns:
-            The list of events (in ascending order) and the token from the start
+            The list of events (in ascending stream order) and the token from the start
             of the chunk of events returned.
         """
         if from_key == to_key:
@@ -510,7 +510,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         if not has_changed:
             return [], from_key
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
             # To handle tokens with a non-empty instance_map we fetch more
             # results than necessary and then filter down
             min_from_id = from_key.stream
@@ -565,6 +565,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     async def get_membership_changes_for_user(
         self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
     ) -> List[EventBase]:
+        """Fetch membership events for a given user.
+
+        All such events whose stream ordering `s` lies in the range
+        `from_key < s <= to_key` are returned. Events are ordered by ascending stream
+        order.
+        """
+        # Start by ruling out cases where a DB query is not necessary.
         if from_key == to_key:
             return []
 
@@ -575,7 +582,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             if not has_changed:
                 return []
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
             # To handle tokens with a non-empty instance_map we fetch more
             # results than necessary and then filter down
             min_from_id = from_key.stream
@@ -634,7 +641,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
         Returns:
             A list of events and a token pointing to the start of the returned
-            events. The events returned are in ascending order.
+            events. The events returned are in ascending topological order.
         """
 
         rows, token = await self.get_recent_event_ids_for_room(
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index d7dc1f73ac..1622822552 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,6 +14,7 @@
 
 import logging
 from collections import namedtuple
+from enum import Enum
 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
 import attr
@@ -44,6 +45,16 @@ _UpdateTransactionRow = namedtuple(
 )
 
 
+class DestinationSortOrder(Enum):
+    """Enum to define the sorting method used when returning destinations."""
+
+    DESTINATION = "destination"
+    RETRY_LAST_TS = "retry_last_ts"
+    RETTRY_INTERVAL = "retry_interval"
+    FAILURE_TS = "failure_ts"
+    LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering"
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class DestinationRetryTimings:
     """The current destination retry timing info for a remote server."""
@@ -480,3 +491,62 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
 
         destinations = [row[0] for row in txn]
         return destinations
+
+    async def get_destinations_paginate(
+        self,
+        start: int,
+        limit: int,
+        destination: Optional[str] = None,
+        order_by: str = DestinationSortOrder.DESTINATION.value,
+        direction: str = "f",
+    ) -> Tuple[List[JsonDict], int]:
+        """Function to retrieve a paginated list of destinations.
+        This will return a json list of destinations and the
+        total number of destinations matching the filter criteria.
+
+        Args:
+            start: start number to begin the query from
+            limit: number of rows to retrieve
+            destination: search string in destination
+            order_by: the sort order of the returned list
+            direction: sort ascending or descending
+        Returns:
+            A tuple of a list of mappings from destination to information
+            and a count of total destinations.
+        """
+
+        def get_destinations_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[JsonDict], int]:
+            order_by_column = DestinationSortOrder(order_by).value
+
+            if direction == "b":
+                order = "DESC"
+            else:
+                order = "ASC"
+
+            args = []
+            where_statement = ""
+            if destination:
+                args.extend(["%" + destination.lower() + "%"])
+                where_statement = "WHERE LOWER(destination) LIKE ?"
+
+            sql_base = f"FROM destinations {where_statement} "
+            sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
+            txn.execute(sql, args)
+            count = txn.fetchone()[0]
+
+            sql = f"""
+                SELECT destination, retry_last_ts, retry_interval, failure_ts,
+                last_successful_stream_ordering
+                {sql_base}
+                ORDER BY {order_by_column} {order}, destination ASC
+                LIMIT ? OFFSET ?
+            """
+            txn.execute(sql, args + [limit, start])
+            destinations = self.db_pool.cursor_to_dict(txn)
+            return destinations, count
+
+        return await self.db_pool.runInteraction(
+            "get_destinations_paginate_txn", get_destinations_paginate_txn
+        )