summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorOlivier Wilkinson (reivilibre) <oliverw@matrix.org>2021-12-14 14:22:01 +0000
committerOlivier Wilkinson (reivilibre) <oliverw@matrix.org>2021-12-14 14:22:01 +0000
commit4dd9ea8f4fc002f35fa604361a792ea1d3d6671c (patch)
tree1690b0d4658301b51907afeacf9e46bda7ce1152 /synapse/storage
parentRevert accidental fast-forward merge from v1.49.0rc1 (diff)
downloadsynapse-4dd9ea8f4fc002f35fa604361a792ea1d3d6671c.tar.xz
Revert "Revert accidental fast-forward merge from v1.49.0rc1"
This reverts commit 158d73ebdd61eef33831ae5f6990acf07244fc55.
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py4
-rw-r--r--synapse/storage/background_updates.py192
-rw-r--r--synapse/storage/databases/main/appservice.py6
-rw-r--r--synapse/storage/databases/main/devices.py50
-rw-r--r--synapse/storage/databases/main/event_federation.py4
-rw-r--r--synapse/storage/databases/main/event_push_actions.py19
-rw-r--r--synapse/storage/databases/main/events.py107
-rw-r--r--synapse/storage/databases/main/events_worker.py581
-rw-r--r--synapse/storage/databases/main/purge_events.py2
-rw-r--r--synapse/storage/databases/main/push_rule.py11
-rw-r--r--synapse/storage/databases/main/registration.py28
-rw-r--r--synapse/storage/databases/main/roommember.py4
-rw-r--r--synapse/storage/databases/main/stream.py15
-rw-r--r--synapse/storage/databases/main/transactions.py70
-rw-r--r--synapse/storage/persist_events.py3
-rw-r--r--synapse/storage/schema/__init__.py6
-rw-r--r--synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql28
-rw-r--r--synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql27
-rw-r--r--synapse/storage/util/id_generators.py116
19 files changed, 1012 insertions, 261 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 0623da9aa1..3056e64ff5 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -21,7 +21,7 @@ from synapse.storage.database import LoggingTransaction  # noqa: F401
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
 from synapse.storage.database import DatabasePool
 from synapse.storage.types import Connection
-from synapse.types import StreamToken, get_domain_from_id
+from synapse.types import get_domain_from_id
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -48,7 +48,7 @@ class SQLBaseStore(metaclass=ABCMeta):
         self,
         stream_name: str,
         instance_name: str,
-        token: StreamToken,
+        token: int,
         rows: Iterable[Any],
     ) -> None:
         pass
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index bc8364400d..d64910aded 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,12 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
+from typing import (
+    TYPE_CHECKING,
+    AsyncContextManager,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Optional,
+)
+
+import attr
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.types import Connection
 from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import Clock, json_encoder
 
 from . import engines
 
@@ -28,6 +38,45 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]]
+DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _BackgroundUpdateHandler:
+    """A handler for a given background update.
+
+    Attributes:
+        callback: The function to call to make progress on the background
+            update.
+        oneshot: Wether the update is likely to happen all in one go, ignoring
+            the supplied target duration, e.g. index creation. This is used by
+            the update controller to help correctly schedule the update.
+    """
+
+    callback: Callable[[JsonDict, int], Awaitable[int]]
+    oneshot: bool = False
+
+
+class _BackgroundUpdateContextManager:
+    BACKGROUND_UPDATE_INTERVAL_MS = 1000
+    BACKGROUND_UPDATE_DURATION_MS = 100
+
+    def __init__(self, sleep: bool, clock: Clock):
+        self._sleep = sleep
+        self._clock = clock
+
+    async def __aenter__(self) -> int:
+        if self._sleep:
+            await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
+
+        return self.BACKGROUND_UPDATE_DURATION_MS
+
+    async def __aexit__(self, *exc) -> None:
+        pass
+
+
 class BackgroundUpdatePerformance:
     """Tracks the how long a background update is taking to update its items"""
 
@@ -84,20 +133,22 @@ class BackgroundUpdater:
 
     MINIMUM_BACKGROUND_BATCH_SIZE = 1
     DEFAULT_BACKGROUND_BATCH_SIZE = 100
-    BACKGROUND_UPDATE_INTERVAL_MS = 1000
-    BACKGROUND_UPDATE_DURATION_MS = 100
 
     def __init__(self, hs: "HomeServer", database: "DatabasePool"):
         self._clock = hs.get_clock()
         self.db_pool = database
 
+        self._database_name = database.name()
+
         # if a background update is currently running, its name.
         self._current_background_update: Optional[str] = None
 
+        self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None
+        self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None
+        self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None
+
         self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
-        self._background_update_handlers: Dict[
-            str, Callable[[JsonDict, int], Awaitable[int]]
-        ] = {}
+        self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {}
         self._all_done = False
 
         # Whether we're currently running updates
@@ -107,6 +158,83 @@ class BackgroundUpdater:
         # enable/disable background updates via the admin API.
         self.enabled = True
 
+    def register_update_controller_callbacks(
+        self,
+        on_update: ON_UPDATE_CALLBACK,
+        default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+        min_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+    ) -> None:
+        """Register callbacks from a module for each hook."""
+        if self._on_update_callback is not None:
+            logger.warning(
+                "More than one module tried to register callbacks for controlling"
+                " background updates. Only the callbacks registered by the first module"
+                " (in order of appearance in Synapse's configuration file) that tried to"
+                " do so will be called."
+            )
+
+            return
+
+        self._on_update_callback = on_update
+
+        if default_batch_size is not None:
+            self._default_batch_size_callback = default_batch_size
+
+        if min_batch_size is not None:
+            self._min_batch_size_callback = min_batch_size
+
+    def _get_context_manager_for_update(
+        self,
+        sleep: bool,
+        update_name: str,
+        database_name: str,
+        oneshot: bool,
+    ) -> AsyncContextManager[int]:
+        """Get a context manager to run a background update with.
+
+        If a module has registered a `update_handler` callback, use the context manager
+        it returns.
+
+        Otherwise, returns a context manager that will return a default value, optionally
+        sleeping if needed.
+
+        Args:
+            sleep: Whether we can sleep between updates.
+            update_name: The name of the update.
+            database_name: The name of the database the update is being run on.
+            oneshot: Whether the update will complete all in one go, e.g. index creation.
+                In such cases the returned target duration is ignored.
+
+        Returns:
+            The target duration in milliseconds that the background update should run for.
+
+            Note: this is a *target*, and an iteration may take substantially longer or
+            shorter.
+        """
+        if self._on_update_callback is not None:
+            return self._on_update_callback(update_name, database_name, oneshot)
+
+        return _BackgroundUpdateContextManager(sleep, self._clock)
+
+    async def _default_batch_size(self, update_name: str, database_name: str) -> int:
+        """The batch size to use for the first iteration of a new background
+        update.
+        """
+        if self._default_batch_size_callback is not None:
+            return await self._default_batch_size_callback(update_name, database_name)
+
+        return self.DEFAULT_BACKGROUND_BATCH_SIZE
+
+    async def _min_batch_size(self, update_name: str, database_name: str) -> int:
+        """A lower bound on the batch size of a new background update.
+
+        Used to ensure that progress is always made. Must be greater than 0.
+        """
+        if self._min_batch_size_callback is not None:
+            return await self._min_batch_size_callback(update_name, database_name)
+
+        return self.MINIMUM_BACKGROUND_BATCH_SIZE
+
     def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
         """Returns the current background update, if any."""
 
@@ -135,13 +263,8 @@ class BackgroundUpdater:
         try:
             logger.info("Starting background schema updates")
             while self.enabled:
-                if sleep:
-                    await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
-
                 try:
-                    result = await self.do_next_background_update(
-                        self.BACKGROUND_UPDATE_DURATION_MS
-                    )
+                    result = await self.do_next_background_update(sleep)
                 except Exception:
                     logger.exception("Error doing update")
                 else:
@@ -203,13 +326,15 @@ class BackgroundUpdater:
 
         return not update_exists
 
-    async def do_next_background_update(self, desired_duration_ms: float) -> bool:
+    async def do_next_background_update(self, sleep: bool = True) -> bool:
         """Does some amount of work on the next queued background update
 
         Returns once some amount of work is done.
 
         Args:
-            desired_duration_ms: How long we want to spend updating.
+            sleep: Whether to limit how quickly we run background updates or
+                not.
+
         Returns:
             True if we have finished running all the background updates, otherwise False
         """
@@ -252,7 +377,19 @@ class BackgroundUpdater:
 
             self._current_background_update = upd["update_name"]
 
-        await self._do_background_update(desired_duration_ms)
+        # We have a background update to run, otherwise we would have returned
+        # early.
+        assert self._current_background_update is not None
+        update_info = self._background_update_handlers[self._current_background_update]
+
+        async with self._get_context_manager_for_update(
+            sleep=sleep,
+            update_name=self._current_background_update,
+            database_name=self._database_name,
+            oneshot=update_info.oneshot,
+        ) as desired_duration_ms:
+            await self._do_background_update(desired_duration_ms)
+
         return False
 
     async def _do_background_update(self, desired_duration_ms: float) -> int:
@@ -260,7 +397,7 @@ class BackgroundUpdater:
         update_name = self._current_background_update
         logger.info("Starting update batch on background update '%s'", update_name)
 
-        update_handler = self._background_update_handlers[update_name]
+        update_handler = self._background_update_handlers[update_name].callback
 
         performance = self._background_update_performance.get(update_name)
 
@@ -273,9 +410,14 @@ class BackgroundUpdater:
         if items_per_ms is not None:
             batch_size = int(desired_duration_ms * items_per_ms)
             # Clamp the batch size so that we always make progress
-            batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE)
+            batch_size = max(
+                batch_size,
+                await self._min_batch_size(update_name, self._database_name),
+            )
         else:
-            batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
+            batch_size = await self._default_batch_size(
+                update_name, self._database_name
+            )
 
         progress_json = await self.db_pool.simple_select_one_onecol(
             "background_updates",
@@ -294,6 +436,8 @@ class BackgroundUpdater:
 
         duration_ms = time_stop - time_start
 
+        performance.update(items_updated, duration_ms)
+
         logger.info(
             "Running background update %r. Processed %r items in %rms."
             " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
@@ -306,8 +450,6 @@ class BackgroundUpdater:
             batch_size,
         )
 
-        performance.update(items_updated, duration_ms)
-
         return len(self._background_update_performance)
 
     def register_background_update_handler(
@@ -331,7 +473,9 @@ class BackgroundUpdater:
             update_name: The name of the update that this code handles.
             update_handler: The function that does the update.
         """
-        self._background_update_handlers[update_name] = update_handler
+        self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+            update_handler
+        )
 
     def register_noop_background_update(self, update_name: str) -> None:
         """Register a noop handler for a background update.
@@ -453,7 +597,9 @@ class BackgroundUpdater:
             await self._end_background_update(update_name)
             return 1
 
-        self.register_background_update_handler(update_name, updater)
+        self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+            updater, oneshot=True
+        )
 
     async def _end_background_update(self, update_name: str) -> None:
         """Removes a completed background update task from the queue.
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index baec35ee27..4a883dc166 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -143,7 +143,7 @@ class ApplicationServiceTransactionWorkerStore(
             A list of ApplicationServices, which may be empty.
         """
         results = await self.db_pool.simple_select_list(
-            "application_services_state", {"state": state}, ["as_id"]
+            "application_services_state", {"state": state.value}, ["as_id"]
         )
         # NB: This assumes this class is linked with ApplicationServiceStore
         as_list = self.get_app_services()
@@ -173,7 +173,7 @@ class ApplicationServiceTransactionWorkerStore(
             desc="get_appservice_state",
         )
         if result:
-            return result.get("state")
+            return ApplicationServiceState(result.get("state"))
         return None
 
     async def set_appservice_state(
@@ -186,7 +186,7 @@ class ApplicationServiceTransactionWorkerStore(
             state: The connectivity state to apply.
         """
         await self.db_pool.simple_upsert(
-            "application_services_state", {"as_id": service.id}, {"state": state}
+            "application_services_state", {"as_id": service.id}, {"state": state.value}
         )
 
     async def create_appservice_txn(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9ccc66e589..d5a4a661cd 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return {d["device_id"]: d for d in devices}
 
+    async def get_devices_by_auth_provider_session_id(
+        self, auth_provider_id: str, auth_provider_session_id: str
+    ) -> List[Dict[str, Any]]:
+        """Retrieve the list of devices associated with a SSO IdP session ID.
+
+        Args:
+            auth_provider_id: The SSO IdP ID as defined in the server config
+            auth_provider_session_id: The session ID within the IdP
+        Returns:
+            A list of dicts containing the device_id and the user_id of each device
+        """
+        return await self.db_pool.simple_select_list(
+            table="device_auth_providers",
+            keyvalues={
+                "auth_provider_id": auth_provider_id,
+                "auth_provider_session_id": auth_provider_session_id,
+            },
+            retcols=("user_id", "device_id"),
+            desc="get_devices_by_auth_provider_session_id",
+        )
+
     @trace
     async def get_device_updates_by_remote(
         self, destination: str, from_stream_id: int, limit: int
@@ -1070,7 +1091,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     async def store_device(
-        self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
+        self,
+        user_id: str,
+        device_id: str,
+        initial_device_display_name: Optional[str],
+        auth_provider_id: Optional[str] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> bool:
         """Ensure the given device is known; add it to the store if not
 
@@ -1079,6 +1105,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             device_id: id of device
             initial_device_display_name: initial displayname of the device.
                 Ignored if device exists.
+            auth_provider_id: The SSO IdP the user used, if any.
+            auth_provider_session_id: The session ID (sid) got from a OIDC login.
 
         Returns:
             Whether the device was inserted or an existing device existed with that ID.
@@ -1115,6 +1143,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 if hidden:
                     raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
 
+            if auth_provider_id and auth_provider_session_id:
+                await self.db_pool.simple_insert(
+                    "device_auth_providers",
+                    values={
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "auth_provider_id": auth_provider_id,
+                        "auth_provider_session_id": auth_provider_session_id,
+                    },
+                    desc="store_device_auth_provider",
+                )
+
             self.device_id_exists_cache.set(key, True)
             return inserted
         except StoreError:
@@ -1168,6 +1208,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 keyvalues={"user_id": user_id},
             )
 
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="device_auth_providers",
+                column="device_id",
+                values=device_ids,
+                keyvalues={"user_id": user_id},
+            )
+
         await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
         for device_id in device_ids:
             self.device_id_exists_cache.invalidate((user_id, device_id))
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ef5d1ef01e..9580a40785 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1552,9 +1552,9 @@ class EventFederationStore(EventFederationWorkerStore):
                 DELETE FROM event_auth
                 WHERE event_id IN (
                     SELECT event_id FROM events
-                    LEFT JOIN state_events USING (room_id, event_id)
+                    LEFT JOIN state_events AS se USING (room_id, event_id)
                     WHERE ? <= stream_ordering AND stream_ordering < ?
-                        AND state_key IS null
+                        AND se.state_key IS null
                 )
             """
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index d957e770dc..3efdd0c920 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -16,6 +16,7 @@ import logging
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import attr
+from typing_extensions import TypedDict
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -37,6 +38,20 @@ DEFAULT_HIGHLIGHT_ACTION = [
 ]
 
 
+class BasePushAction(TypedDict):
+    event_id: str
+    actions: List[Union[dict, str]]
+
+
+class HttpPushAction(BasePushAction):
+    room_id: str
+    stream_ordering: int
+
+
+class EmailPushAction(HttpPushAction):
+    received_ts: Optional[int]
+
+
 def _serialize_action(actions, is_highlight):
     """Custom serializer for actions. This allows us to "compress" common actions.
 
@@ -221,7 +236,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         min_stream_ordering: int,
         max_stream_ordering: int,
         limit: int = 20,
-    ) -> List[dict]:
+    ) -> List[HttpPushAction]:
         """Get a list of the most recent unread push actions for a given user,
         within the given stream ordering range. Called by the httppusher.
 
@@ -326,7 +341,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         min_stream_ordering: int,
         max_stream_ordering: int,
         limit: int = 20,
-    ) -> List[dict]:
+    ) -> List[EmailPushAction]:
         """Get a list of the most recent unread push actions for a given user,
         within the given stream ordering range. Called by the emailpusher
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 06832221ad..4e528612ea 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 import itertools
 import logging
-from collections import OrderedDict, namedtuple
+from collections import OrderedDict
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -41,9 +41,10 @@ from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.logging.utils import log_function
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.events_worker import EventCacheEntry
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder
@@ -64,9 +65,6 @@ event_counter = Counter(
 )
 
 
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-
-
 @attr.s(slots=True)
 class DeltaState:
     """Deltas to use to update the `current_state_events` table.
@@ -108,23 +106,30 @@ class PersistEventsStore:
         self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
 
-        # Ideally we'd move these ID gens here, unfortunately some other ID
-        # generators are chained off them so doing so is a bit of a PITA.
-        self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
-        self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
-
         # This should only exist on instances that are configured to write
         assert (
             hs.get_instance_name() in hs.config.worker.writers.events
         ), "Can only instantiate EventsStore on master"
 
+        # Since we have been configured to write, we ought to have id generators,
+        # rather than id trackers.
+        assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
+        assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
+
+        # Ideally we'd move these ID gens here, unfortunately some other ID
+        # generators are chained off them so doing so is a bit of a PITA.
+        self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
+        self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+
     async def _persist_events_and_state_updates(
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
+        *,
         current_state_for_room: Dict[str, StateMap[str]],
         state_delta_for_room: Dict[str, DeltaState],
         new_forward_extremeties: Dict[str, List[str]],
-        backfilled: bool = False,
+        use_negative_stream_ordering: bool = False,
+        inhibit_local_membership_updates: bool = False,
     ) -> None:
         """Persist a set of events alongside updates to the current state and
         forward extremities tables.
@@ -137,7 +142,14 @@ class PersistEventsStore:
                 room state
             new_forward_extremities: Map from room_id to list of event IDs
                 that are the new forward extremities of the room.
-            backfilled
+            use_negative_stream_ordering: Whether to start stream_ordering on
+                the negative side and decrement. This should be set as True
+                for backfilled events because backfilled events get a negative
+                stream ordering so they don't come down incremental `/sync`.
+            inhibit_local_membership_updates: Stop the local_current_membership
+                from being updated by these events. This should be set to True
+                for backfilled events because backfilled events in the past do
+                not affect the current local state.
 
         Returns:
             Resolves when the events have been persisted
@@ -159,7 +171,7 @@ class PersistEventsStore:
         #
         # Note: Multiple instances of this function cannot be in flight at
         # the same time for the same room.
-        if backfilled:
+        if use_negative_stream_ordering:
             stream_ordering_manager = self._backfill_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
@@ -176,13 +188,13 @@ class PersistEventsStore:
                 "persist_events",
                 self._persist_events_txn,
                 events_and_contexts=events_and_contexts,
-                backfilled=backfilled,
+                inhibit_local_membership_updates=inhibit_local_membership_updates,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremeties=new_forward_extremeties,
             )
             persist_event_counter.inc(len(events_and_contexts))
 
-            if not backfilled:
+            if stream < 0:
                 # backfilled events have negative stream orderings, so we don't
                 # want to set the event_persisted_position to that.
                 synapse.metrics.event_persisted_position.set(
@@ -316,8 +328,9 @@ class PersistEventsStore:
     def _persist_events_txn(
         self,
         txn: LoggingTransaction,
+        *,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-        backfilled: bool,
+        inhibit_local_membership_updates: bool = False,
         state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
         new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
     ):
@@ -330,7 +343,10 @@ class PersistEventsStore:
         Args:
             txn
             events_and_contexts: events to persist
-            backfilled: True if the events were backfilled
+            inhibit_local_membership_updates: Stop the local_current_membership
+                from being updated by these events. This should be set to True
+                for backfilled events because backfilled events in the past do
+                not affect the current local state.
             delete_existing True to purge existing table rows for the events
                 from the database. This is useful when retrying due to
                 IntegrityError.
@@ -363,9 +379,7 @@ class PersistEventsStore:
             events_and_contexts
         )
 
-        self._update_room_depths_txn(
-            txn, events_and_contexts=events_and_contexts, backfilled=backfilled
-        )
+        self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
 
         # _update_outliers_txn filters out any events which have already been
         # persisted, and returns the filtered list.
@@ -398,7 +412,7 @@ class PersistEventsStore:
             txn,
             events_and_contexts=events_and_contexts,
             all_events_and_contexts=all_events_and_contexts,
-            backfilled=backfilled,
+            inhibit_local_membership_updates=inhibit_local_membership_updates,
         )
 
         # We call this last as it assumes we've inserted the events into
@@ -561,9 +575,9 @@ class PersistEventsStore:
         # fetch their auth event info.
         while missing_auth_chains:
             sql = """
-                SELECT event_id, events.type, state_key, chain_id, sequence_number
+                SELECT event_id, events.type, se.state_key, chain_id, sequence_number
                 FROM events
-                INNER JOIN state_events USING (event_id)
+                INNER JOIN state_events AS se USING (event_id)
                 LEFT JOIN event_auth_chains USING (event_id)
                 WHERE
             """
@@ -1200,7 +1214,6 @@ class PersistEventsStore:
         self,
         txn,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-        backfilled: bool,
     ):
         """Update min_depth for each room
 
@@ -1208,13 +1221,18 @@ class PersistEventsStore:
             txn (twisted.enterprise.adbapi.Connection): db connection
             events_and_contexts (list[(EventBase, EventContext)]): events
                 we are persisting
-            backfilled (bool): True if the events were backfilled
         """
         depth_updates: Dict[str, int] = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
             txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
-            if not backfilled:
+            # Then update the `stream_ordering` position to mark the latest
+            # event as the front of the room. This should not be done for
+            # backfilled events because backfilled events have negative
+            # stream_ordering and happened in the past so we know that we don't
+            # need to update the stream_ordering tip/front for the room.
+            assert event.internal_metadata.stream_ordering is not None
+            if event.internal_metadata.stream_ordering >= 0:
                 txn.call_after(
                     self.store._events_stream_cache.entity_has_changed,
                     event.room_id,
@@ -1427,7 +1445,12 @@ class PersistEventsStore:
         return [ec for ec in events_and_contexts if ec[0] not in to_remove]
 
     def _update_metadata_tables_txn(
-        self, txn, events_and_contexts, all_events_and_contexts, backfilled
+        self,
+        txn,
+        *,
+        events_and_contexts,
+        all_events_and_contexts,
+        inhibit_local_membership_updates: bool = False,
     ):
         """Update all the miscellaneous tables for new events
 
@@ -1439,7 +1462,10 @@ class PersistEventsStore:
                 events that we were going to persist. This includes events
                 we've already persisted, etc, that wouldn't appear in
                 events_and_context.
-            backfilled (bool): True if the events were backfilled
+            inhibit_local_membership_updates: Stop the local_current_membership
+                from being updated by these events. This should be set to True
+                for backfilled events because backfilled events in the past do
+                not affect the current local state.
         """
 
         # Insert all the push actions into the event_push_actions table.
@@ -1513,7 +1539,7 @@ class PersistEventsStore:
                 for event, _ in events_and_contexts
                 if event.type == EventTypes.Member
             ],
-            backfilled=backfilled,
+            inhibit_local_membership_updates=inhibit_local_membership_updates,
         )
 
         # Insert event_reference_hashes table.
@@ -1553,11 +1579,13 @@ class PersistEventsStore:
         for row in rows:
             event = ev_map[row["event_id"]]
             if not row["rejects"] and not row["redacts"]:
-                to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
+                to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
         def prefill():
             for cache_entry in to_prefill:
-                self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
+                self.store._get_event_cache.set(
+                    (cache_entry.event.event_id,), cache_entry
+                )
 
         txn.call_after(prefill)
 
@@ -1638,8 +1666,19 @@ class PersistEventsStore:
             txn, table="event_reference_hashes", values=vals
         )
 
-    def _store_room_members_txn(self, txn, events, backfilled):
-        """Store a room member in the database."""
+    def _store_room_members_txn(
+        self, txn, events, *, inhibit_local_membership_updates: bool = False
+    ):
+        """
+        Store a room member in the database.
+        Args:
+            txn: The transaction to use.
+            events: List of events to store.
+            inhibit_local_membership_updates: Stop the local_current_membership
+                from being updated by these events. This should be set to True
+                for backfilled events because backfilled events in the past do
+                not affect the current local state.
+        """
 
         def non_null_str_or_none(val: Any) -> Optional[str]:
             return val if isinstance(val, str) and "\u0000" not in val else None
@@ -1682,7 +1721,7 @@ class PersistEventsStore:
             # band membership", like a remote invite or a rejection of a remote invite.
             if (
                 self.is_mine_id(event.state_key)
-                and not backfilled
+                and not inhibit_local_membership_updates
                 and event.internal_metadata.is_outlier()
                 and event.internal_metadata.is_out_of_band_membership()
             ):
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c6bf316d5b..c7b660ac5a 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -15,14 +15,18 @@
 import logging
 import threading
 from typing import (
+    TYPE_CHECKING,
+    Any,
     Collection,
     Container,
     Dict,
     Iterable,
     List,
+    NoReturn,
     Optional,
     Set,
     Tuple,
+    cast,
     overload,
 )
 
@@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
     EventFormatVersions,
+    RoomVersion,
     RoomVersions,
 )
 from synapse.events import EventBase, make_event_from_dict
@@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdTracker,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import unwrapFirstError
@@ -69,10 +82,13 @@ from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# These values are used in the `enqueue_event` and `_fetch_loop` methods to
 # control how we batch/bulk fetch events from the database.
 # The values are plucked out of thing air to make initial sync run faster
 # on jki.re
@@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(
 
 
 @attr.s(slots=True, auto_attribs=True)
-class _EventCacheEntry:
+class EventCacheEntry:
     event: EventBase
     redacted_event: Optional[EventBase]
 
@@ -129,7 +145,7 @@ class _EventRow:
     json: str
     internal_metadata: str
     format_version: Optional[int]
-    room_version_id: Optional[int]
+    room_version_id: Optional[str]
     rejected_reason: Optional[str]
     redactions: List[str]
     outlier: bool
@@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
     # options controlling this.
     USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
+        self._stream_id_gen: AbstractStreamIdTracker
+        self._backfill_id_gen: AbstractStreamIdTracker
         if isinstance(database.engine, PostgresEngine):
             # If we're using Postgres than we can use `MultiWriterIdGenerator`
             # regardless of whether this process writes to the streams or not.
@@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
                 5 * 60 * 1000,
             )
 
-        self._get_event_cache = LruCache(
+        self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
             cache_name="*getEvent*",
             max_size=hs.config.caches.event_cache_size,
         )
@@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
         # ID to cache entry. Note that the returned dict may not have the
         # requested event in it if the event isn't in the DB.
         self._current_event_fetches: Dict[
-            str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+            str, ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = {}
 
         self._event_fetch_lock = threading.Condition()
-        self._event_fetch_list = []
+        self._event_fetch_list: List[
+            Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
+        ] = []
         self._event_fetch_ongoing = 0
         event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
 
         # We define this sequence here so that it can be referenced from both
         # the DataStore and PersistEventStore.
-        def get_chain_id_txn(txn):
+        def get_chain_id_txn(txn: Cursor) -> int:
             txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
-            return txn.fetchone()[0]
+            return cast(Tuple[int], txn.fetchone())[0]
 
         self.event_chain_id_gen = build_sequence_generator(
             db_conn,
@@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
             id_column="chain_id",
         )
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
         if stream_name == EventsStream.NAME:
             self._stream_id_gen.advance(instance_name, token)
         elif stream_name == BackfillStream.NAME:
@@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
-        get_prev_content: bool = False,
-        allow_rejected: bool = False,
-        allow_none: Literal[False] = False,
-        check_room_id: Optional[str] = None,
+        get_prev_content: bool = ...,
+        allow_rejected: bool = ...,
+        allow_none: Literal[False] = ...,
+        check_room_id: Optional[str] = ...,
     ) -> EventBase:
         ...
 
@@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore):
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
-        get_prev_content: bool = False,
-        allow_rejected: bool = False,
-        allow_none: Literal[True] = False,
-        check_room_id: Optional[str] = None,
+        get_prev_content: bool = ...,
+        allow_rejected: bool = ...,
+        allow_none: Literal[True] = ...,
+        check_room_id: Optional[str] = ...,
     ) -> Optional[EventBase]:
         ...
 
@@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_events(
         self,
-        event_ids: Iterable[str],
+        event_ids: Collection[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
@@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def _get_events_from_cache_or_db(
         self, event_ids: Iterable[str], allow_rejected: bool = False
-    ) -> Dict[str, _EventCacheEntry]:
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch a bunch of events from the cache or the database.
 
         If events are pulled from the database, they will be cached for future lookups.
@@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
         # same dict into itself N times).
         already_fetching_ids: Set[str] = set()
         already_fetching_deferreds: Set[
-            ObservableDeferred[Dict[str, _EventCacheEntry]]
+            ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = set()
 
         for event_id in missing_events_ids:
@@ -601,8 +632,8 @@ class EventsWorkerStore(SQLBaseStore):
             # function returning more events than requested, but that can happen
             # already due to `_get_events_from_db`).
             fetching_deferred: ObservableDeferred[
-                Dict[str, _EventCacheEntry]
-            ] = ObservableDeferred(defer.Deferred())
+                Dict[str, EventCacheEntry]
+            ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
             for event_id in missing_events_ids:
                 self._current_event_fetches[event_id] = fetching_deferred
 
@@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_entry_map
 
-    def _invalidate_get_event_cache(self, event_id):
+    def _invalidate_get_event_cache(self, event_id: str) -> None:
         self._get_event_cache.invalidate((event_id,))
 
     def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
-    ) -> Dict[str, _EventCacheEntry]:
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch events from the caches.
 
         May return rejected events.
@@ -736,38 +767,123 @@ class EventsWorkerStore(SQLBaseStore):
             for e in state_to_include.values()
         ]
 
-    def _do_fetch(self, conn: Connection) -> None:
+    def _maybe_start_fetch_thread(self) -> None:
+        """Starts an event fetch thread if we are not yet at the maximum number."""
+        with self._event_fetch_lock:
+            if (
+                self._event_fetch_list
+                and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
+            ):
+                self._event_fetch_ongoing += 1
+                event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+                # `_event_fetch_ongoing` is decremented in `_fetch_thread`.
+                should_start = True
+            else:
+                should_start = False
+
+        if should_start:
+            run_as_background_process("fetch_events", self._fetch_thread)
+
+    async def _fetch_thread(self) -> None:
+        """Services requests for events from `_event_fetch_list`."""
+        exc = None
+        try:
+            await self.db_pool.runWithConnection(self._fetch_loop)
+        except BaseException as e:
+            exc = e
+            raise
+        finally:
+            should_restart = False
+            event_fetches_to_fail = []
+            with self._event_fetch_lock:
+                self._event_fetch_ongoing -= 1
+                event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+
+                # There may still be work remaining in `_event_fetch_list` if we
+                # failed, or it was added in between us deciding to exit and
+                # decrementing `_event_fetch_ongoing`.
+                if self._event_fetch_list:
+                    if exc is None:
+                        # We decided to exit, but then some more work was added
+                        # before `_event_fetch_ongoing` was decremented.
+                        # If a new event fetch thread was not started, we should
+                        # restart ourselves since the remaining event fetch threads
+                        # may take a while to get around to the new work.
+                        #
+                        # Unfortunately it is not possible to tell whether a new
+                        # event fetch thread was started, so we restart
+                        # unconditionally. If we are unlucky, we will end up with
+                        # an idle fetch thread, but it will time out after
+                        # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
+                        # in any case.
+                        #
+                        # Note that multiple fetch threads may run down this path at
+                        # the same time.
+                        should_restart = True
+                    elif isinstance(exc, Exception):
+                        if self._event_fetch_ongoing == 0:
+                            # We were the last remaining fetcher and failed.
+                            # Fail any outstanding fetches since no one else will
+                            # handle them.
+                            event_fetches_to_fail = self._event_fetch_list
+                            self._event_fetch_list = []
+                        else:
+                            # We weren't the last remaining fetcher, so another
+                            # fetcher will pick up the work. This will either happen
+                            # after their existing work, however long that takes,
+                            # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
+                            # they are idle.
+                            pass
+                    else:
+                        # The exception is a `SystemExit`, `KeyboardInterrupt` or
+                        # `GeneratorExit`. Don't try to do anything clever here.
+                        pass
+
+            if should_restart:
+                # We exited cleanly but noticed more work.
+                self._maybe_start_fetch_thread()
+
+            if event_fetches_to_fail:
+                # We were the last remaining fetcher and failed.
+                # Fail any outstanding fetches since no one else will handle them.
+                assert exc is not None
+                with PreserveLoggingContext():
+                    for _, deferred in event_fetches_to_fail:
+                        deferred.errback(exc)
+
+    def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
         """Takes a database connection and waits for requests for events from
         the _event_fetch_list queue.
         """
-        try:
-            i = 0
-            while True:
-                with self._event_fetch_lock:
-                    event_list = self._event_fetch_list
-                    self._event_fetch_list = []
-
-                    if not event_list:
-                        single_threaded = self.database_engine.single_threaded
-                        if (
-                            not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
-                            or single_threaded
-                            or i > EVENT_QUEUE_ITERATIONS
-                        ):
-                            break
-                        else:
-                            self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
-                            i += 1
-                            continue
-                    i = 0
+        i = 0
+        while True:
+            with self._event_fetch_lock:
+                event_list = self._event_fetch_list
+                self._event_fetch_list = []
+
+                if not event_list:
+                    # There are no requests waiting. If we haven't yet reached the
+                    # maximum iteration limit, wait for some more requests to turn up.
+                    # Otherwise, bail out.
+                    single_threaded = self.database_engine.single_threaded
+                    if (
+                        not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+                        or single_threaded
+                        or i > EVENT_QUEUE_ITERATIONS
+                    ):
+                        return
+
+                    self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+                    i += 1
+                    continue
+                i = 0
 
-                self._fetch_event_list(conn, event_list)
-        finally:
-            self._event_fetch_ongoing -= 1
-            event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+            self._fetch_event_list(conn, event_list)
 
     def _fetch_event_list(
-        self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+        self,
+        conn: LoggingDatabaseConnection,
+        event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
     ) -> None:
         """Handle a load of requests from the _event_fetch_list queue
 
@@ -794,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore):
                 )
 
                 # We only want to resolve deferreds from the main thread
-                def fire():
+                def fire() -> None:
                     for _, d in event_list:
                         d.callback(row_dict)
 
@@ -804,18 +920,16 @@ class EventsWorkerStore(SQLBaseStore):
                 logger.exception("do_fetch")
 
                 # We only want to resolve deferreds from the main thread
-                def fire(evs, exc):
-                    for _, d in evs:
-                        if not d.called:
-                            with PreserveLoggingContext():
-                                d.errback(exc)
+                def fire_errback(exc: Exception) -> None:
+                    for _, d in event_list:
+                        d.errback(exc)
 
                 with PreserveLoggingContext():
-                    self.hs.get_reactor().callFromThread(fire, event_list, e)
+                    self.hs.get_reactor().callFromThread(fire_errback, e)
 
     async def _get_events_from_db(
-        self, event_ids: Iterable[str]
-    ) -> Dict[str, _EventCacheEntry]:
+        self, event_ids: Collection[str]
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch a bunch of events from the database.
 
         May return rejected events.
@@ -831,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore):
             map from event id to result. May return extra events which
             weren't asked for.
         """
-        fetched_events = {}
+        fetched_event_ids: Set[str] = set()
+        fetched_events: Dict[str, _EventRow] = {}
         events_to_fetch = event_ids
 
         while events_to_fetch:
             row_map = await self._enqueue_events(events_to_fetch)
 
             # we need to recursively fetch any redactions of those events
-            redaction_ids = set()
+            redaction_ids: Set[str] = set()
             for event_id in events_to_fetch:
                 row = row_map.get(event_id)
-                fetched_events[event_id] = row
+                fetched_event_ids.add(event_id)
                 if row:
+                    fetched_events[event_id] = row
                     redaction_ids.update(row.redactions)
 
-            events_to_fetch = redaction_ids.difference(fetched_events.keys())
+            events_to_fetch = redaction_ids.difference(fetched_event_ids)
             if events_to_fetch:
                 logger.debug("Also fetching redaction events %s", events_to_fetch)
 
         # build a map from event_id to EventBase
-        event_map = {}
+        event_map: Dict[str, EventBase] = {}
         for event_id, row in fetched_events.items():
-            if not row:
-                continue
             assert row.event_id == event_id
 
             rejected_reason = row.rejected_reason
@@ -881,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             room_version_id = row.room_version_id
 
+            room_version: Optional[RoomVersion]
             if not room_version_id:
                 # this should only happen for out-of-band membership events which
                 # arrived before #6983 landed. For all other events, we should have
@@ -951,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         # finally, we can decide whether each one needs redacting, and build
         # the cache entries.
-        result_map = {}
+        result_map: Dict[str, EventCacheEntry] = {}
         for event_id, original_ev in event_map.items():
             redactions = fetched_events[event_id].redactions
             redacted_event = self._maybe_redact_event_row(
                 original_ev, redactions, event_map
             )
 
-            cache_entry = _EventCacheEntry(
+            cache_entry = EventCacheEntry(
                 event=original_ev, redacted_event=redacted_event
             )
 
@@ -967,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return result_map
 
-    async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
+    async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.
@@ -980,23 +1095,12 @@ class EventsWorkerStore(SQLBaseStore):
             that weren't requested.
         """
 
-        events_d = defer.Deferred()
+        events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
         with self._event_fetch_lock:
             self._event_fetch_list.append((events, events_d))
-
             self._event_fetch_lock.notify()
 
-            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
-                self._event_fetch_ongoing += 1
-                event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
-                should_start = True
-            else:
-                should_start = False
-
-        if should_start:
-            run_as_background_process(
-                "fetch_events", self.db_pool.runWithConnection, self._do_fetch
-            )
+        self._maybe_start_fetch_thread()
 
         logger.debug("Loading %d events: %s", len(events), events)
         with PreserveLoggingContext():
@@ -1146,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore):
         # no valid redaction found for this event
         return None
 
-    async def have_events_in_timeline(self, event_ids):
+    async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.
         """
@@ -1175,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore):
             event_ids: events we are looking for
 
         Returns:
-            set[str]: The events we have already seen.
+            The set of events we have already seen.
         """
         res = await self._have_seen_events_dict(
             (room_id, event_id) for event_id in event_ids
@@ -1198,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore):
         }
         results = {x: True for x in cache_results}
 
-        def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
+        def have_seen_events_txn(
+            txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
+        ) -> None:
             # we deliberately do *not* query the database for room_id, to make the
             # query an index-only lookup on `events_event_id_key`.
             #
@@ -1224,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore):
         return results
 
     @cached(max_entries=100000, tree=True)
-    async def have_seen_event(self, room_id: str, event_id: str):
+    async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
         # this only exists for the benefit of the @cachedList descriptor on
         # _have_seen_events_dict
         raise NotImplementedError()
 
-    def _get_current_state_event_counts_txn(self, txn, room_id):
+    def _get_current_state_event_counts_txn(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> int:
         """
         See get_current_state_event_counts.
         """
@@ -1254,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore):
             room_id,
         )
 
-    async def get_room_complexity(self, room_id):
+    async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
         """
         Get a rough approximation of the complexity of the room. This is used by
         remote servers to decide whether they wish to join the room or not.
@@ -1262,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore):
         more resources.
 
         Args:
-            room_id (str)
+            room_id: The room ID to query.
 
         Returns:
-            dict[str:int] of complexity version to complexity.
+            dict[str:float] of complexity version to complexity.
         """
         state_events = await self.get_current_state_event_counts(room_id)
 
@@ -1275,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {"v1": complexity_v1}
 
-    def get_current_events_token(self):
+    def get_current_events_token(self) -> int:
         """The current maximum token that events have reached"""
         return self._stream_id_gen.get_current_token()
 
     async def get_all_new_forward_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> List[Tuple]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
         """Returns new events, for the Events replication stream
 
         Args:
@@ -1295,13 +1403,15 @@ class EventsWorkerStore(SQLBaseStore):
             EventsStreamRow.
         """
 
-        def get_all_new_forward_event_rows(txn):
+        def get_all_new_forward_event_rows(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
             sql = (
                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
                 " FROM events AS e"
                 " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN state_events AS se USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " LEFT JOIN room_memberships USING (event_id)"
                 " LEFT JOIN rejections USING (event_id)"
@@ -1311,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore):
                 " LIMIT ?"
             )
             txn.execute(sql, (last_id, current_id, instance_name, limit))
-            return txn.fetchall()
+            return cast(
+                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+            )
 
         return await self.db_pool.runInteraction(
             "get_all_new_forward_event_rows", get_all_new_forward_event_rows
@@ -1319,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_ex_outlier_stream_rows(
         self, instance_name: str, last_id: int, current_id: int
-    ) -> List[Tuple]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
         """Returns de-outliered events, for the Events replication stream
 
         Args:
@@ -1332,14 +1444,16 @@ class EventsWorkerStore(SQLBaseStore):
             EventsStreamRow.
         """
 
-        def get_ex_outlier_stream_rows_txn(txn):
+        def get_ex_outlier_stream_rows_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
             sql = (
                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
                 " FROM events AS e"
                 " INNER JOIN ex_outlier_stream AS out USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN state_events AS se USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " LEFT JOIN room_memberships USING (event_id)"
                 " LEFT JOIN rejections USING (event_id)"
@@ -1350,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore):
             )
 
             txn.execute(sql, (last_id, current_id, instance_name))
-            return txn.fetchall()
+            return cast(
+                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+            )
 
         return await self.db_pool.runInteraction(
             "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
@@ -1358,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_all_new_backfill_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, list]], int, bool]:
+    ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
         """Get updates for backfill replication stream, including all new
         backfilled events and events that have gone from being outliers to not.
 
@@ -1386,13 +1502,15 @@ class EventsWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_new_backfill_event_rows(txn):
+        def get_all_new_backfill_event_rows(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
             sql = (
                 "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
+                " se.state_key, redacts, relates_to_id"
                 " FROM events AS e"
                 " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN state_events AS se USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " WHERE ? > stream_ordering AND stream_ordering >= ?"
                 "  AND instance_name = ?"
@@ -1400,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore):
                 " LIMIT ?"
             )
             txn.execute(sql, (-last_id, -current_id, instance_name, limit))
-            new_event_updates = [(row[0], row[1:]) for row in txn]
+            new_event_updates: List[
+                Tuple[int, Tuple[str, str, str, str, str, str]]
+            ] = []
+            row: Tuple[int, str, str, str, str, str, str]
+            # Type safety: iterating over `txn` yields `Tuple`, i.e.
+            # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+            # variadic tuple to a fixed length tuple and flags it up as an error.
+            for row in txn:  # type: ignore[assignment]
+                new_event_updates.append((row[0], row[1:]))
 
             limited = False
             if len(new_event_updates) == limit:
@@ -1411,11 +1537,11 @@ class EventsWorkerStore(SQLBaseStore):
 
             sql = (
                 "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
+                " se.state_key, redacts, relates_to_id"
                 " FROM events AS e"
                 " INNER JOIN ex_outlier_stream AS out USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN state_events AS se USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " WHERE ? > event_stream_ordering"
                 " AND event_stream_ordering >= ?"
@@ -1423,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore):
                 " ORDER BY event_stream_ordering DESC"
             )
             txn.execute(sql, (-last_id, -upper_bound, instance_name))
-            new_event_updates.extend((row[0], row[1:]) for row in txn)
+            # Type safety: iterating over `txn` yields `Tuple`, i.e.
+            # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+            # variadic tuple to a fixed length tuple and flags it up as an error.
+            for row in txn:  # type: ignore[assignment]
+                new_event_updates.append((row[0], row[1:]))
 
             if len(new_event_updates) >= limit:
                 upper_bound = new_event_updates[-1][0]
@@ -1437,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_all_updated_current_state_deltas(
         self, instance_name: str, from_token: int, to_token: int, target_row_count: int
-    ) -> Tuple[List[Tuple], int, bool]:
+    ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
         """Fetch updates from current_state_delta_stream
 
         Args:
@@ -1457,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore):
                * `limited` is whether there are more updates to fetch.
         """
 
-        def get_all_updated_current_state_deltas_txn(txn):
+        def get_all_updated_current_state_deltas_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str]]:
             sql = """
                 SELECT stream_id, room_id, type, state_key, event_id
                 FROM current_state_delta_stream
@@ -1466,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore):
                 ORDER BY stream_id ASC LIMIT ?
             """
             txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
 
-        def get_deltas_for_stream_id_txn(txn, stream_id):
+        def get_deltas_for_stream_id_txn(
+            txn: LoggingTransaction, stream_id: int
+        ) -> List[Tuple[int, str, str, str, str]]:
             sql = """
                 SELECT stream_id, room_id, type, state_key, event_id
                 FROM current_state_delta_stream
                 WHERE stream_id = ?
             """
             txn.execute(sql, [stream_id])
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
 
         # we need to make sure that, for every stream id in the results, we get *all*
         # the rows with that stream id.
 
-        rows: List[Tuple] = await self.db_pool.runInteraction(
+        rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
         )
@@ -1509,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         return rows, to_token, True
 
-    async def is_event_after(self, event_id1, event_id2):
+    async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
         """Returns True if event_id1 is after event_id2 in the stream"""
         to_1, so_1 = await self.get_event_ordering(event_id1)
         to_2, so_2 = await self.get_event_ordering(event_id2)
         return (to_1, so_1) > (to_2, so_2)
 
     @cached(max_entries=5000)
-    async def get_event_ordering(self, event_id):
+    async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
         res = await self.db_pool.simple_select_one(
             table="events",
             retcols=["topological_ordering", "stream_ordering"],
@@ -1539,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore):
             None otherwise.
         """
 
-        def get_next_event_to_expire_txn(txn):
+        def get_next_event_to_expire_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Tuple[str, int]]:
             txn.execute(
                 """
                 SELECT event_id, expiry_ts FROM event_expiry
@@ -1547,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore):
                 """
             )
 
-            return txn.fetchone()
+            return cast(Optional[Tuple[str, int]], txn.fetchone())
 
         return await self.db_pool.runInteraction(
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
@@ -1611,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore):
         return mapping
 
     @wrap_as_background_process("_cleanup_old_transaction_ids")
-    async def _cleanup_old_transaction_ids(self):
+    async def _cleanup_old_transaction_ids(self) -> None:
         """Cleans out transaction id mappings older than 24hrs."""
 
-        def _cleanup_old_transaction_ids_txn(txn):
+        def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
             sql = """
                 DELETE FROM event_txn_id
                 WHERE inserted_ts < ?
@@ -1626,3 +1762,198 @@ class EventsWorkerStore(SQLBaseStore):
             "_cleanup_old_transaction_ids",
             _cleanup_old_transaction_ids_txn,
         )
+
+    async def is_event_next_to_backward_gap(self, event: EventBase) -> bool:
+        """Check if the given event is next to a backward gap of missing events.
+        <latest messages> A(False)--->B(False)--->C(True)--->  <gap, unknown events> <oldest messages>
+
+        Args:
+            room_id: room where the event lives
+            event_id: event to check
+
+        Returns:
+            Boolean indicating whether it's an extremity
+        """
+
+        def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool:
+            # If the event in question has any of its prev_events listed as a
+            # backward extremity, it's next to a gap.
+            #
+            # We can't just check the backward edges in `event_edges` because
+            # when we persist events, we will also record the prev_events as
+            # edges to the event in question regardless of whether we have those
+            # prev_events yet. We need to check whether those prev_events are
+            # backward extremities, also known as gaps, that need to be
+            # backfilled.
+            backward_extremity_query = """
+                SELECT 1 FROM event_backward_extremities
+                WHERE
+                    room_id = ?
+                    AND %s
+                LIMIT 1
+            """
+
+            # If the event in question is a backward extremity or has any of its
+            # prev_events listed as a backward extremity, it's next to a
+            # backward gap.
+            clause, args = make_in_list_sql_clause(
+                self.database_engine,
+                "event_id",
+                [event.event_id] + list(event.prev_event_ids()),
+            )
+
+            txn.execute(backward_extremity_query % (clause,), [event.room_id] + args)
+            backward_extremities = txn.fetchall()
+
+            # We consider any backward extremity as a backward gap
+            if len(backward_extremities):
+                return True
+
+            return False
+
+        return await self.db_pool.runInteraction(
+            "is_event_next_to_backward_gap_txn",
+            is_event_next_to_backward_gap_txn,
+        )
+
+    async def is_event_next_to_forward_gap(self, event: EventBase) -> bool:
+        """Check if the given event is next to a forward gap of missing events.
+        The gap in front of the latest events is not considered a gap.
+        <latest messages> A(False)--->B(False)--->C(False)--->  <gap, unknown events> <oldest messages>
+        <latest messages> A(False)--->B(False)--->  <gap, unknown events>  --->D(True)--->E(False) <oldest messages>
+
+        Args:
+            room_id: room where the event lives
+            event_id: event to check
+
+        Returns:
+            Boolean indicating whether it's an extremity
+        """
+
+        def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool:
+            # If the event in question is a forward extremity, we will just
+            # consider any potential forward gap as not a gap since it's one of
+            # the latest events in the room.
+            #
+            # `event_forward_extremities` does not include backfilled or outlier
+            # events so we can't rely on it to find forward gaps. We can only
+            # use it to determine whether a message is the latest in the room.
+            #
+            # We can't combine this query with the `forward_edge_query` below
+            # because if the event in question has no forward edges (isn't
+            # referenced by any other event's prev_events) but is in
+            # `event_forward_extremities`, we don't want to return 0 rows and
+            # say it's next to a gap.
+            forward_extremity_query = """
+                SELECT 1 FROM event_forward_extremities
+                WHERE
+                    room_id = ?
+                    AND event_id = ?
+                LIMIT 1
+            """
+
+            # Check to see whether the event in question is already referenced
+            # by another event. If we don't see any edges, we're next to a
+            # forward gap.
+            forward_edge_query = """
+                SELECT 1 FROM event_edges
+                /* Check to make sure the event referencing our event in question is not rejected */
+                LEFT JOIN rejections ON event_edges.event_id == rejections.event_id
+                WHERE
+                    event_edges.room_id = ?
+                    AND event_edges.prev_event_id = ?
+                    /* It's not a valid edge if the event referencing our event in
+                     * question is rejected.
+                     */
+                    AND rejections.event_id IS NULL
+                LIMIT 1
+            """
+
+            # We consider any forward extremity as the latest in the room and
+            # not a forward gap.
+            #
+            # To expand, even though there is technically a gap at the front of
+            # the room where the forward extremities are, we consider those the
+            # latest messages in the room so asking other homeservers for more
+            # is useless. The new latest messages will just be federated as
+            # usual.
+            txn.execute(forward_extremity_query, (event.room_id, event.event_id))
+            forward_extremities = txn.fetchall()
+            if len(forward_extremities):
+                return False
+
+            # If there are no forward edges to the event in question (another
+            # event hasn't referenced this event in their prev_events), then we
+            # assume there is a forward gap in the history.
+            txn.execute(forward_edge_query, (event.room_id, event.event_id))
+            forward_edges = txn.fetchall()
+            if not len(forward_edges):
+                return True
+
+            return False
+
+        return await self.db_pool.runInteraction(
+            "is_event_next_to_gap_txn",
+            is_event_next_to_gap_txn,
+        )
+
+    async def get_event_id_for_timestamp(
+        self, room_id: str, timestamp: int, direction: str
+    ) -> Optional[str]:
+        """Find the closest event to the given timestamp in the given direction.
+
+        Args:
+            room_id: Room to fetch the event from
+            timestamp: The point in time (inclusive) we should navigate from in
+                the given direction to find the closest event.
+            direction: ["f"|"b"] to indicate whether we should navigate forward
+                or backward from the given timestamp to find the closest event.
+
+        Returns:
+            The closest event_id otherwise None if we can't find any event in
+            the given direction.
+        """
+
+        sql_template = """
+            SELECT event_id FROM events
+            LEFT JOIN rejections USING (event_id)
+            WHERE
+                origin_server_ts %s ?
+                AND room_id = ?
+                /* Make sure event is not rejected */
+                AND rejections.event_id IS NULL
+            ORDER BY origin_server_ts %s
+            LIMIT 1;
+        """
+
+        def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
+            if direction == "b":
+                # Find closest event *before* a given timestamp. We use descending
+                # (which gives values largest to smallest) because we want the
+                # largest possible timestamp *before* the given timestamp.
+                comparison_operator = "<="
+                order = "DESC"
+            else:
+                # Find closest event *after* a given timestamp. We use ascending
+                # (which gives values smallest to largest) because we want the
+                # closest possible timestamp *after* the given timestamp.
+                comparison_operator = ">="
+                order = "ASC"
+
+            txn.execute(
+                sql_template % (comparison_operator, order), (timestamp, room_id)
+            )
+            row = txn.fetchone()
+            if row:
+                (event_id,) = row
+                return event_id
+
+            return None
+
+        if direction not in ("f", "b"):
+            raise ValueError("Unknown direction: %s" % (direction,))
+
+        return await self.db_pool.runInteraction(
+            "get_event_id_for_timestamp_txn",
+            get_event_id_for_timestamp_txn,
+        )
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 3eb30944bf..91b0576b85 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -118,7 +118,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
 
         logger.info("[purge] looking for events to delete")
 
-        should_delete_expr = "state_key IS NULL"
+        should_delete_expr = "state_events.state_key IS NULL"
         should_delete_params: Tuple[Any, ...] = ()
         if not delete_local_events:
             should_delete_expr += " AND event_id NOT LIKE ?"
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index fa782023d4..3b63267395 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,7 +28,10 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdTracker,
+    StreamIdGenerator,
+)
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -82,9 +85,9 @@ class PushRulesWorkerStore(
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
-            self._push_rules_stream_id_gen: Union[
-                StreamIdGenerator, SlavedIdTracker
-            ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
+            self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+                db_conn, "push_rules_stream", "stream_id"
+            )
         else:
             self._push_rules_stream_id_gen = SlavedIdTracker(
                 db_conn, "push_rules_stream", "stream_id"
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 0e8c168667..e1ddf06916 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -106,6 +106,15 @@ class RefreshTokenLookupResult:
     has_next_access_token_been_used: bool
     """True if the next access token was already used at least once."""
 
+    expiry_ts: Optional[int]
+    """The time at which the refresh token expires and can not be used.
+    If None, the refresh token doesn't expire."""
+
+    ultimate_session_expiry_ts: Optional[int]
+    """The time at which the session comes to an end and can no longer be
+    refreshed.
+    If None, the session can be refreshed indefinitely."""
+
 
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     def __init__(
@@ -1626,8 +1635,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                     rt.user_id,
                     rt.device_id,
                     rt.next_token_id,
-                    (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
-                    at.used has_next_access_token_been_used
+                    (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed,
+                    at.used AS has_next_access_token_been_used,
+                    rt.expiry_ts,
+                    rt.ultimate_session_expiry_ts
                 FROM refresh_tokens rt
                 LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
                 LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
@@ -1648,6 +1659,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 has_next_refresh_token_been_refreshed=row[4],
                 # This column is nullable, ensure it's a boolean
                 has_next_access_token_been_used=(row[5] or False),
+                expiry_ts=row[6],
+                ultimate_session_expiry_ts=row[7],
             )
 
         return await self.db_pool.runInteraction(
@@ -1915,6 +1928,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         user_id: str,
         token: str,
         device_id: Optional[str],
+        expiry_ts: Optional[int],
+        ultimate_session_expiry_ts: Optional[int],
     ) -> int:
         """Adds a refresh token for the given user.
 
@@ -1922,6 +1937,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             user_id: The user ID.
             token: The new access token to add.
             device_id: ID of the device to associate with the refresh token.
+            expiry_ts (milliseconds since the epoch): Time after which the
+                refresh token cannot be used.
+                If None, the refresh token never expires until it has been used.
+            ultimate_session_expiry_ts (milliseconds since the epoch):
+                Time at which the session will end and can not be extended any
+                further.
+                If None, the session can be refreshed indefinitely.
         Raises:
             StoreError if there was a problem adding this.
         Returns:
@@ -1937,6 +1959,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 "device_id": device_id,
                 "token": token,
                 "next_token_id": None,
+                "expiry_ts": expiry_ts,
+                "ultimate_session_expiry_ts": ultimate_session_expiry_ts,
             },
             desc="add_refresh_token_to_user",
         )
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 033a9831d6..6b2a8d06a6 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -476,7 +476,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
                     c.type = 'm.room.member'
-                    AND state_key = ?
+                    AND c.state_key = ?
                     AND c.membership = ?
             """
         else:
@@ -487,7 +487,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
                     c.type = 'm.room.member'
-                    AND state_key = ?
+                    AND c.state_key = ?
                     AND m.membership = ?
             """
 
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 42dc807d17..57aab55259 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -497,7 +497,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
                 oldest `limit` events.
 
         Returns:
-            The list of events (in ascending order) and the token from the start
+            The list of events (in ascending stream order) and the token from the start
             of the chunk of events returned.
         """
         if from_key == to_key:
@@ -510,7 +510,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         if not has_changed:
             return [], from_key
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
             # To handle tokens with a non-empty instance_map we fetch more
             # results than necessary and then filter down
             min_from_id = from_key.stream
@@ -565,6 +565,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     async def get_membership_changes_for_user(
         self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
     ) -> List[EventBase]:
+        """Fetch membership events for a given user.
+
+        All such events whose stream ordering `s` lies in the range
+        `from_key < s <= to_key` are returned. Events are ordered by ascending stream
+        order.
+        """
+        # Start by ruling out cases where a DB query is not necessary.
         if from_key == to_key:
             return []
 
@@ -575,7 +582,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             if not has_changed:
                 return []
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
             # To handle tokens with a non-empty instance_map we fetch more
             # results than necessary and then filter down
             min_from_id = from_key.stream
@@ -634,7 +641,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
         Returns:
             A list of events and a token pointing to the start of the returned
-            events. The events returned are in ascending order.
+            events. The events returned are in ascending topological order.
         """
 
         rows, token = await self.get_recent_event_ids_for_room(
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index d7dc1f73ac..1622822552 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,6 +14,7 @@
 
 import logging
 from collections import namedtuple
+from enum import Enum
 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
 import attr
@@ -44,6 +45,16 @@ _UpdateTransactionRow = namedtuple(
 )
 
 
+class DestinationSortOrder(Enum):
+    """Enum to define the sorting method used when returning destinations."""
+
+    DESTINATION = "destination"
+    RETRY_LAST_TS = "retry_last_ts"
+    RETTRY_INTERVAL = "retry_interval"
+    FAILURE_TS = "failure_ts"
+    LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering"
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class DestinationRetryTimings:
     """The current destination retry timing info for a remote server."""
@@ -480,3 +491,62 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
 
         destinations = [row[0] for row in txn]
         return destinations
+
+    async def get_destinations_paginate(
+        self,
+        start: int,
+        limit: int,
+        destination: Optional[str] = None,
+        order_by: str = DestinationSortOrder.DESTINATION.value,
+        direction: str = "f",
+    ) -> Tuple[List[JsonDict], int]:
+        """Function to retrieve a paginated list of destinations.
+        This will return a json list of destinations and the
+        total number of destinations matching the filter criteria.
+
+        Args:
+            start: start number to begin the query from
+            limit: number of rows to retrieve
+            destination: search string in destination
+            order_by: the sort order of the returned list
+            direction: sort ascending or descending
+        Returns:
+            A tuple of a list of mappings from destination to information
+            and a count of total destinations.
+        """
+
+        def get_destinations_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[JsonDict], int]:
+            order_by_column = DestinationSortOrder(order_by).value
+
+            if direction == "b":
+                order = "DESC"
+            else:
+                order = "ASC"
+
+            args = []
+            where_statement = ""
+            if destination:
+                args.extend(["%" + destination.lower() + "%"])
+                where_statement = "WHERE LOWER(destination) LIKE ?"
+
+            sql_base = f"FROM destinations {where_statement} "
+            sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
+            txn.execute(sql, args)
+            count = txn.fetchone()[0]
+
+            sql = f"""
+                SELECT destination, retry_last_ts, retry_interval, failure_ts,
+                last_successful_stream_ordering
+                {sql_base}
+                ORDER BY {order_by_column} {order}, destination ASC
+                LIMIT ? OFFSET ?
+            """
+            txn.execute(sql, args + [limit, start])
+            destinations = self.db_pool.cursor_to_dict(txn)
+            return destinations, count
+
+        return await self.db_pool.runInteraction(
+            "get_destinations_paginate_txn", get_destinations_paginate_txn
+        )
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 402f134d89..428d66a617 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -583,7 +583,8 @@ class EventsPersistenceStorage:
                 current_state_for_room=current_state_for_room,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremeties=new_forward_extremeties,
-                backfilled=backfilled,
+                use_negative_stream_ordering=backfilled,
+                inhibit_local_membership_updates=backfilled,
             )
 
             await self._handle_potentially_left_users(potentially_left_users)
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 3a00ed6835..50d08094d5 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SCHEMA_VERSION = 65  # remember to update the list below when updating
+SCHEMA_VERSION = 66  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
@@ -46,6 +46,10 @@ Changes in SCHEMA_VERSION = 65:
     - MSC2716: Remove unique event_id constraint from insertion_event_edges
       because an insertion event can have multiple edges.
     - Remove unused tables `user_stats_historical` and `room_stats_historical`.
+
+Changes in SCHEMA_VERSION = 66:
+    - Queries on state_key columns are now disambiguated (ie, the codebase can handle
+      the `events` table having a `state_key` column).
 """
 
 
diff --git a/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql
new file mode 100644
index 0000000000..bdc491c817
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql
@@ -0,0 +1,28 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+ALTER TABLE refresh_tokens
+  -- We add an expiry_ts column (in milliseconds since the Epoch) to refresh tokens.
+  -- They may not be used after they have expired.
+  -- If null, then the refresh token's lifetime is unlimited.
+  ADD COLUMN expiry_ts BIGINT DEFAULT NULL;
+
+ALTER TABLE refresh_tokens
+  -- We also add an ultimate session expiry time (in milliseconds since the Epoch).
+  -- No matter how much the access and refresh tokens are refreshed, they cannot
+  -- be extended past this time.
+  -- If null, then the session length is unlimited.
+  ADD COLUMN ultimate_session_expiry_ts BIGINT DEFAULT NULL;
diff --git a/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql
new file mode 100644
index 0000000000..a65bfb520d
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql
@@ -0,0 +1,27 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Track the auth provider used by each login as well as the session ID
+CREATE TABLE device_auth_providers (
+  user_id TEXT NOT NULL,
+  device_id TEXT NOT NULL,
+  auth_provider_id TEXT NOT NULL,
+  auth_provider_session_id TEXT NOT NULL
+);
+
+CREATE INDEX device_auth_providers_devices
+  ON device_auth_providers (user_id, device_id);
+CREATE INDEX device_auth_providers_sessions
+  ON device_auth_providers (auth_provider_id, auth_provider_session_id);
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index ac56bc9a05..4ff3013908 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -89,31 +89,77 @@ def _load_current_id(
     return (max if step > 0 else min)(current_id, step)
 
 
-class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
-    @abc.abstractmethod
-    def get_next(self) -> AsyncContextManager[int]:
-        raise NotImplementedError()
+class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
+    """Tracks the "current" stream ID of a stream that may have multiple writers.
+
+    Stream IDs are monotonically increasing or decreasing integers representing write
+    transactions. The "current" stream ID is the stream ID such that all transactions
+    with equal or smaller stream IDs have completed. Since transactions may complete out
+    of order, this is not the same as the stream ID of the last completed transaction.
+
+    Completed transactions include both committed transactions and transactions that
+    have been rolled back.
+    """
 
     @abc.abstractmethod
-    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+    def advance(self, instance_name: str, new_id: int) -> None:
+        """Advance the position of the named writer to the given ID, if greater
+        than existing entry.
+        """
         raise NotImplementedError()
 
     @abc.abstractmethod
     def get_current_token(self) -> int:
+        """Returns the maximum stream id such that all stream ids less than or
+        equal to it have been successfully persisted.
+
+        Returns:
+            The maximum stream id.
+        """
         raise NotImplementedError()
 
     @abc.abstractmethod
     def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+
+        For streams with single writers this is equivalent to `get_current_token`.
+        """
+        raise NotImplementedError()
+
+
+class AbstractStreamIdGenerator(AbstractStreamIdTracker):
+    """Generates stream IDs for a stream that may have multiple writers.
+
+    Each stream ID represents a write transaction, whose completion is tracked
+    so that the "current" stream ID of the stream can be determined.
+
+    See `AbstractStreamIdTracker` for more details.
+    """
+
+    @abc.abstractmethod
+    def get_next(self) -> AsyncContextManager[int]:
+        """
+        Usage:
+            async with stream_id_gen.get_next() as stream_id:
+                # ... persist event ...
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+        """
+        Usage:
+            async with stream_id_gen.get_next(n) as stream_ids:
+                # ... persist events ...
+        """
         raise NotImplementedError()
 
 
 class StreamIdGenerator(AbstractStreamIdGenerator):
-    """Used to generate new stream ids when persisting events while keeping
-    track of which transactions have been completed.
+    """Generates and tracks stream IDs for a stream with a single writer.
 
-    This allows us to get the "current" stream id, i.e. the stream id such that
-    all ids less than or equal to it have completed. This handles the fact that
-    persistence of events can complete out of order.
+    This class must only be used when the current Synapse process is the sole
+    writer for a stream.
 
     Args:
         db_conn(connection):  A database connection to use to fetch the
@@ -157,12 +203,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         # The key and values are the same, but we never look at the values.
         self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
 
+    def advance(self, instance_name: str, new_id: int) -> None:
+        # `StreamIdGenerator` should only be used when there is a single writer,
+        # so replication should never happen.
+        raise Exception("Replication is not supported by StreamIdGenerator")
+
     def get_next(self) -> AsyncContextManager[int]:
-        """
-        Usage:
-            async with stream_id_gen.get_next() as stream_id:
-                # ... persist event ...
-        """
         with self._lock:
             self._current += self._step
             next_id = self._current
@@ -180,11 +226,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         return _AsyncCtxManagerWrapper(manager())
 
     def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
-        """
-        Usage:
-            async with stream_id_gen.get_next(n) as stream_ids:
-                # ... persist events ...
-        """
         with self._lock:
             next_ids = range(
                 self._current + self._step,
@@ -208,12 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         return _AsyncCtxManagerWrapper(manager())
 
     def get_current_token(self) -> int:
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
-
-        Returns:
-            The maximum stream id.
-        """
         with self._lock:
             if self._unfinished_ids:
                 return next(iter(self._unfinished_ids)) - self._step
@@ -221,16 +256,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
             return self._current
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer.
-
-        For streams with single writers this is equivalent to
-        `get_current_token`.
-        """
         return self.get_current_token()
 
 
 class MultiWriterIdGenerator(AbstractStreamIdGenerator):
-    """An ID generator that tracks a stream that can have multiple writers.
+    """Generates and tracks stream IDs for a stream with multiple writers.
 
     Uses a Postgres sequence to coordinate ID assignment, but positions of other
     writers will only get updated when `advance` is called (by replication).
@@ -475,12 +505,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         return stream_ids
 
     def get_next(self) -> AsyncContextManager[int]:
-        """
-        Usage:
-            async with stream_id_gen.get_next() as stream_id:
-                # ... persist event ...
-        """
-
         # If we have a list of instances that are allowed to write to this
         # stream, make sure we're in it.
         if self._writers and self._instance_name not in self._writers:
@@ -492,12 +516,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
 
     def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
-        """
-        Usage:
-            async with stream_id_gen.get_next_mult(5) as stream_ids:
-                # ... persist events ...
-        """
-
         # If we have a list of instances that are allowed to write to this
         # stream, make sure we're in it.
         if self._writers and self._instance_name not in self._writers:
@@ -597,15 +615,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             self._add_persisted_position(next_id)
 
     def get_current_token(self) -> int:
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
-        """
-
         return self.get_persisted_upto_position()
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer."""
-
         # If we don't have an entry for the given instance name, we assume it's a
         # new writer.
         #
@@ -631,10 +643,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             }
 
     def advance(self, instance_name: str, new_id: int) -> None:
-        """Advance the position of the named writer to the given ID, if greater
-        than existing entry.
-        """
-
         new_id *= self._return_factor
 
         with self._lock: