-rw-r--r--  changelog.d/8868.misc                                                    2
-rw-r--r--  changelog.d/9029.misc                                                    1
-rw-r--r--  changelog.d/9098.misc                                                    1
-rw-r--r--  changelog.d/9107.feature                                                 1
-rw-r--r--  changelog.d/9114.bugfix                                                  1
-rw-r--r--  changelog.d/9115.misc                                                    1
-rw-r--r--  changelog.d/9116.bugfix                                                  1
-rw-r--r--  changelog.d/9118.misc                                                    1
-rwxr-xr-x  scripts/synapse_port_db                                                  2
-rw-r--r--  synapse/app/homeserver.py                                                1
-rw-r--r--  synapse/handlers/oidc_handler.py                                       246
-rw-r--r--  synapse/http/federation/matrix_federation_agent.py                      1
-rw-r--r--  synapse/storage/database.py                                             11
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py                       41
-rw-r--r--  synapse/storage/databases/main/events.py                                82
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py                    197
-rw-r--r--  synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql  17
-rw-r--r--  synapse/storage/databases/main/transactions.py                          24
-rw-r--r--  synapse/storage/util/sequence.py                                        10
-rw-r--r--  synapse/util/iterutils.py                                                2
-rw-r--r--  tests/handlers/test_oidc.py                                             93
-rw-r--r--  tests/storage/test_event_chain.py                                      114
-rw-r--r--  tests/util/test_itertools.py                                             8
23 files changed, 643 insertions, 215 deletions
diff --git a/changelog.d/8868.misc b/changelog.d/8868.misc
index 1a11e30944..346741d982 100644
--- a/changelog.d/8868.misc
+++ b/changelog.d/8868.misc
@@ -1 +1 @@
-Improve efficiency of large state resolutions for new rooms.
+Improve efficiency of large state resolutions.
diff --git a/changelog.d/9029.misc b/changelog.d/9029.misc
new file mode 100644
index 0000000000..346741d982
--- /dev/null
+++ b/changelog.d/9029.misc
@@ -0,0 +1 @@
+Improve efficiency of large state resolutions.
diff --git a/changelog.d/9098.misc b/changelog.d/9098.misc
new file mode 100644
index 0000000000..907020d428
--- /dev/null
+++ b/changelog.d/9098.misc
@@ -0,0 +1 @@
+Fix the wrong arguments being passed to `BlacklistingAgentWrapper` from `MatrixFederationAgent`. Contributed by Timothy Leung.
diff --git a/changelog.d/9107.feature b/changelog.d/9107.feature
new file mode 100644
index 0000000000..01a24dcf49
--- /dev/null
+++ b/changelog.d/9107.feature
@@ -0,0 +1 @@
+Add support for multiple SSO Identity Providers.
diff --git a/changelog.d/9114.bugfix b/changelog.d/9114.bugfix
new file mode 100644
index 0000000000..211f26589d
--- /dev/null
+++ b/changelog.d/9114.bugfix
@@ -0,0 +1 @@
+Fix bugs in federation catchup logic that caused outbound federation to be delayed for large servers after startup. Introduced in v1.8.0 and v1.21.0.
diff --git a/changelog.d/9115.misc b/changelog.d/9115.misc
new file mode 100644
index 0000000000..346741d982
--- /dev/null
+++ b/changelog.d/9115.misc
@@ -0,0 +1 @@
+Improve efficiency of large state resolutions.
diff --git a/changelog.d/9116.bugfix b/changelog.d/9116.bugfix
new file mode 100644
index 0000000000..211f26589d
--- /dev/null
+++ b/changelog.d/9116.bugfix
@@ -0,0 +1 @@
+Fix bugs in federation catchup logic that caused outbound federation to be delayed for large servers after startup. Introduced in v1.8.0 and v1.21.0.
diff --git a/changelog.d/9118.misc b/changelog.d/9118.misc
new file mode 100644
index 0000000000..346741d982
--- /dev/null
+++ b/changelog.d/9118.misc
@@ -0,0 +1 @@
+Improve efficiency of large state resolutions.
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 22dd169bfb..69bf9110a6 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -70,7 +70,7 @@ logger = logging.getLogger("synapse_port_db")
 
 BOOLEAN_COLUMNS = {
     "events": ["processed", "outlier", "contains_url"],
-    "rooms": ["is_public"],
+    "rooms": ["is_public", "has_auth_chain_index"],
     "event_edges": ["is_state"],
     "presence_list": ["accepted"],
     "presence_stream": ["currently_active"],
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index cbecf23be6..57a2f5237c 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -429,7 +429,6 @@ def setup(config_options):
             oidc = hs.get_oidc_handler()
             # Loading the provider metadata also ensures the provider config is valid.
             await oidc.load_metadata()
-            await oidc.load_jwks()
 
         await _base.start(hs, config.listeners)
 
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 84754e5c9c..d6347bb1b8 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -35,6 +35,7 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 
 from synapse.config import ConfigError
+from synapse.config.oidc_config import OidcProviderConfig
 from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
@@ -70,6 +71,131 @@ JWK = Dict[str, str]
 JWKS = TypedDict("JWKS", {"keys": List[JWK]})
 
 
+class OidcHandler:
+    """Handles requests related to the OpenID Connect login flow.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        self._sso_handler = hs.get_sso_handler()
+
+        provider_conf = hs.config.oidc.oidc_provider
+        # we should not have been instantiated if there is no configured provider.
+        assert provider_conf is not None
+
+        self._token_generator = OidcSessionTokenGenerator(hs)
+
+        self._provider = OidcProvider(hs, self._token_generator, provider_conf)
+
+    async def load_metadata(self) -> None:
+        """Validate the config and load the metadata from the remote endpoint.
+
+        Called at startup to ensure we have everything we need.
+        """
+        await self._provider.load_metadata()
+        await self._provider.load_jwks()
+
+    async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+        """Handle an incoming request to /_synapse/oidc/callback
+
+        Since we might want to display OIDC-related errors in a user-friendly
+        way, we don't raise SynapseError from here. Instead, we call
+        ``self._sso_handler.render_error`` which displays an HTML page for the error.
+
+        Most of the OpenID Connect logic happens here:
+
+          - first, we check if there was any error returned by the provider and
+            display it
+          - then we fetch the session cookie, decode and verify it
+          - the ``state`` query parameter should match the one stored in the
+            session cookie
+
+        Once we know the session is legit, we then delegate to the OIDC Provider
+        implementation, which will exchange the code with the provider and complete the
+        login/authentication.
+
+        Args:
+            request: the incoming request from the browser.
+        """
+
+        # The provider might redirect with an error.
+        # In that case, just display it as-is.
+        if b"error" in request.args:
+            # error response from the auth server. see:
+            #  https://tools.ietf.org/html/rfc6749#section-4.1.2.1
+            #  https://openid.net/specs/openid-connect-core-1_0.html#AuthError
+            error = request.args[b"error"][0].decode()
+            description = request.args.get(b"error_description", [b""])[0].decode()
+
+            # Most of the errors returned by the provider could be due to
+            # either the provider misbehaving or Synapse being misconfigured.
+            # The only exception to that is "access_denied", where the user
+            # probably cancelled the login flow. In other cases, log those errors.
+            if error != "access_denied":
+                logger.error("Error from the OIDC provider: %s %s", error, description)
+
+            self._sso_handler.render_error(request, error, description)
+            return
+
+        # otherwise, it is presumably a successful response. see:
+        #   https://tools.ietf.org/html/rfc6749#section-4.1.2
+
+        # Fetch the session cookie
+        session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
+        if session is None:
+            logger.info("No session cookie found")
+            self._sso_handler.render_error(
+                request, "missing_session", "No session cookie found"
+            )
+            return
+
+        # Remove the cookie. There is a good chance that if the callback failed
+        # once, it will fail next time and the code will already be exchanged.
+        # Removing it early avoids spamming the provider with token requests.
+        request.addCookie(
+            SESSION_COOKIE_NAME,
+            b"",
+            path="/_synapse/oidc",
+            expires="Thu, Jan 01 1970 00:00:00 UTC",
+            httpOnly=True,
+            sameSite="lax",
+        )
+
+        # Check for the state query parameter
+        if b"state" not in request.args:
+            logger.info("State parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "State parameter is missing"
+            )
+            return
+
+        state = request.args[b"state"][0].decode()
+
+        # Deserialize the session token and verify it.
+        try:
+            session_data = self._token_generator.verify_oidc_session_token(
+                session, state
+            )
+        except MacaroonDeserializationException as e:
+            logger.exception("Invalid session")
+            self._sso_handler.render_error(request, "invalid_session", str(e))
+            return
+        except MacaroonInvalidSignatureException as e:
+            logger.exception("Could not verify session")
+            self._sso_handler.render_error(request, "mismatching_session", str(e))
+            return
+
+        if b"code" not in request.args:
+            logger.info("Code parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "Code parameter is missing"
+            )
+            return
+
+        code = request.args[b"code"][0].decode()
+
+        await self._provider.handle_oidc_callback(request, session_data, code)
+
+
 class OidcError(Exception):
     """Used to catch errors when calling the token_endpoint
     """
@@ -84,21 +210,25 @@ class OidcError(Exception):
         return self.error
 
 
-class OidcHandler:
-    """Handles requests related to the OpenID Connect login flow.
+class OidcProvider:
+    """Wraps the config for a single OIDC IdentityProvider
+
+    Provides methods for handling redirect requests and callbacks via that particular
+    IdP.
     """
 
-    def __init__(self, hs: "HomeServer"):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        token_generator: "OidcSessionTokenGenerator",
+        provider: OidcProviderConfig,
+    ):
         self._store = hs.get_datastore()
 
-        self._token_generator = OidcSessionTokenGenerator(hs)
+        self._token_generator = token_generator
 
         self._callback_url = hs.config.oidc_callback_url  # type: str
 
-        provider = hs.config.oidc.oidc_provider
-        # we should not have been instantiated if there is no configured provider.
-        assert provider is not None
-
         self._scopes = provider.scopes
         self._user_profile_method = provider.user_profile_method
         self._client_auth = ClientAuth(
@@ -552,22 +682,16 @@ class OidcHandler:
             nonce=nonce,
         )
 
-    async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+    async def handle_oidc_callback(
+        self, request: SynapseRequest, session_data: "OidcSessionData", code: str
+    ) -> None:
         """Handle an incoming request to /_synapse/oidc/callback
 
-        Since we might want to display OIDC-related errors in a user-friendly
-        way, we don't raise SynapseError from here. Instead, we call
-        ``self._sso_handler.render_error`` which displays an HTML page for the error.
+        By this time we have already validated the session on the synapse side, and
+        now need to do the provider-specific operations. This includes:
 
-        Most of the OpenID Connect logic happens here:
-
-          - first, we check if there was any error returned by the provider and
-            display it
-          - then we fetch the session cookie, decode and verify it
-          - the ``state`` query parameter should match with the one stored in the
-            session cookie
-          - once we known this session is legit, exchange the code with the
-            provider using the ``token_endpoint`` (see ``_exchange_code``)
+          - exchange the code with the provider using the ``token_endpoint`` (see
+            ``_exchange_code``)
           - once we have the token, use it to either extract the UserInfo from
             the ``id_token`` (``_parse_id_token``), or use the ``access_token``
             to fetch UserInfo from the ``userinfo_endpoint``
@@ -577,86 +701,12 @@ class OidcHandler:
 
         Args:
             request: the incoming request from the browser.
+            session_data: the session data, extracted from our cookie
+            code: The authorization code we got from the callback.
         """
-
-        # The provider might redirect with an error.
-        # In that case, just display it as-is.
-        if b"error" in request.args:
-            # error response from the auth server. see:
-            #  https://tools.ietf.org/html/rfc6749#section-4.1.2.1
-            #  https://openid.net/specs/openid-connect-core-1_0.html#AuthError
-            error = request.args[b"error"][0].decode()
-            description = request.args.get(b"error_description", [b""])[0].decode()
-
-            # Most of the errors returned by the provider could be due by
-            # either the provider misbehaving or Synapse being misconfigured.
-            # The only exception of that is "access_denied", where the user
-            # probably cancelled the login flow. In other cases, log those errors.
-            if error != "access_denied":
-                logger.error("Error from the OIDC provider: %s %s", error, description)
-
-            self._sso_handler.render_error(request, error, description)
-            return
-
-        # otherwise, it is presumably a successful response. see:
-        #   https://tools.ietf.org/html/rfc6749#section-4.1.2
-
-        # Fetch the session cookie
-        session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
-        if session is None:
-            logger.info("No session cookie found")
-            self._sso_handler.render_error(
-                request, "missing_session", "No session cookie found"
-            )
-            return
-
-        # Remove the cookie. There is a good chance that if the callback failed
-        # once, it will fail next time and the code will already be exchanged.
-        # Removing it early avoids spamming the provider with token requests.
-        request.addCookie(
-            SESSION_COOKIE_NAME,
-            b"",
-            path="/_synapse/oidc",
-            expires="Thu, Jan 01 1970 00:00:00 UTC",
-            httpOnly=True,
-            sameSite="lax",
-        )
-
-        # Check for the state query parameter
-        if b"state" not in request.args:
-            logger.info("State parameter is missing")
-            self._sso_handler.render_error(
-                request, "invalid_request", "State parameter is missing"
-            )
-            return
-
-        state = request.args[b"state"][0].decode()
-
-        # Deserialize the session token and verify it.
-        try:
-            session_data = self._token_generator.verify_oidc_session_token(
-                session, state
-            )
-        except MacaroonDeserializationException as e:
-            logger.exception("Invalid session")
-            self._sso_handler.render_error(request, "invalid_session", str(e))
-            return
-        except MacaroonInvalidSignatureException as e:
-            logger.exception("Could not verify session")
-            self._sso_handler.render_error(request, "mismatching_session", str(e))
-            return
-
         # Exchange the code with the provider
-        if b"code" not in request.args:
-            logger.info("Code parameter is missing")
-            self._sso_handler.render_error(
-                request, "invalid_request", "Code parameter is missing"
-            )
-            return
-
-        logger.debug("Exchanging code")
-        code = request.args[b"code"][0].decode()
         try:
+            logger.debug("Exchanging code")
             token = await self._exchange_code(code)
         except OidcError as e:
             logger.exception("Could not exchange code")
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 3b756a7dc2..4c06a117d3 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -102,7 +102,6 @@ class MatrixFederationAgent:
                         pool=self._pool,
                         contextFactory=tls_client_options_factory,
                     ),
-                    self._reactor,
                     ip_blacklist=ip_blacklist,
                 ),
                 user_agent=self.user_agent,
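
The one-line removal above is the fix described in changelog 9098: an extra positional argument was being passed to BlacklistingAgentWrapper, where it silently bound to the wrong parameter. A toy illustration of that failure mode (the signature below is invented, not the real BlacklistingAgentWrapper):

# Toy illustration of how a stray positional argument silently binds to
# the wrong parameter; this signature is hypothetical.
class Wrapper:
    def __init__(self, agent, ip_whitelist=None, ip_blacklist=None):
        self.agent = agent
        self.ip_whitelist = ip_whitelist
        self.ip_blacklist = ip_blacklist


reactor = object()
w = Wrapper(object(), reactor, ip_blacklist=["1.2.3.4"])
assert w.ip_whitelist is reactor  # the reactor landed in the whitelist slot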
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 6cfadc2b4e..a19d65ad23 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -49,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
+from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import Collection
 
 # python 3 does not have a maximum int value
@@ -412,6 +413,16 @@ class DatabasePool:
                 self._check_safe_to_upsert,
             )
 
+        # We define this sequence here so that it can be referenced from both
+        # the DataStore and PersistEventStore.
+        def get_chain_id_txn(txn):
+            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+            return txn.fetchone()[0]
+
+        self.event_chain_id_gen = build_sequence_generator(
+            engine, get_chain_id_txn, "event_auth_chain_id"
+        )
+
     def is_running(self) -> bool:
         """Is the database pool currently running
         """
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1b6ccd51c8..c128889bf9 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict
 from synapse.util import json_encoder
@@ -513,21 +514,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
         for user_chunk in batch_iter(user_ids, 100):
             clause, params = make_in_list_sql_clause(
-                txn.database_engine, "k.user_id", user_chunk
-            )
-            sql = (
-                """
-                SELECT k.user_id, k.keytype, k.keydata, k.stream_id
-                  FROM e2e_cross_signing_keys k
-                  INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
-                                FROM e2e_cross_signing_keys
-                               GROUP BY user_id, keytype) s
-                 USING (user_id, stream_id, keytype)
-                 WHERE
-            """
-                + clause
+                txn.database_engine, "user_id", user_chunk
             )
 
+            # Fetch the latest key for each type per user.
+            if isinstance(self.database_engine, PostgresEngine):
+                # The `DISTINCT ON` clause will pick the *first* row it
+                # encounters, so ordering by stream ID desc will ensure we get
+                # the latest key.
+                sql = """
+                    SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id
+                        FROM e2e_cross_signing_keys
+                        WHERE %(clause)s
+                        ORDER BY user_id, keytype, stream_id DESC
+                """ % {
+                    "clause": clause
+                }
+            else:
+                # SQLite has special handling for bare columns when using
+                # MIN/MAX with a `GROUP BY` clause where it picks the value from
+                # a row that matches the MIN/MAX.
+                sql = """
+                    SELECT user_id, keytype, keydata, MAX(stream_id)
+                        FROM e2e_cross_signing_keys
+                        WHERE %(clause)s
+                        GROUP BY user_id, keytype
+                """ % {
+                    "clause": clause
+                }
+
             txn.execute(sql, params)
             rows = self.db_pool.cursor_to_dict(txn)
 
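The SQLite branch above leans on SQLite's documented "bare column" behaviour: when a query uses a lone MIN()/MAX() aggregate with GROUP BY, the other selected columns are taken from the row that supplied the minimum or maximum. A runnable demonstration against an illustrative table (not the real Synapse schema):

# Demonstrates the SQLite "bare column" behaviour relied on above: with
# MAX(...) and GROUP BY, the other selected columns come from the row
# that supplied the maximum.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE keys (user_id TEXT, keytype TEXT, keydata TEXT, stream_id INTEGER)"
)
conn.executemany(
    "INSERT INTO keys VALUES (?, ?, ?, ?)",
    [("@a:test", "master", "old", 1), ("@a:test", "master", "new", 2)],
)
row = conn.execute(
    "SELECT user_id, keytype, keydata, MAX(stream_id)"
    " FROM keys GROUP BY user_id, keytype"
).fetchone()
assert row == ("@a:test", "master", "new", 2)  # keydata came from the MAX row
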
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 186f064036..3216b3f3c8 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -43,7 +43,6 @@ from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
@@ -100,14 +99,6 @@ class PersistEventsStore:
         self._clock = hs.get_clock()
         self._instance_name = hs.get_instance_name()
 
-        def get_chain_id_txn(txn):
-            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
-            return txn.fetchone()[0]
-
-        self._event_chain_id_gen = build_sequence_generator(
-            db.engine, get_chain_id_txn, "event_auth_chain_id"
-        )
-
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
 
@@ -466,9 +457,6 @@ class PersistEventsStore:
         if not state_events:
             return
 
-        # Map from event ID to chain ID/sequence number.
-        chain_map = {}  # type: Dict[str, Tuple[int, int]]
-
         # We need to know the type/state_key and auth events of the events we're
         # calculating chain IDs for. We don't rely on having the full Event
         # instances as we'll potentially be pulling more events from the DB and
@@ -479,19 +467,44 @@ class PersistEventsStore:
         event_to_auth_chain = {
             e.event_id: e.auth_event_ids() for e in state_events.values()
         }
+        event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
+
+        self._add_chain_cover_index(
+            txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+        )
+
+    @staticmethod
+    def _add_chain_cover_index(
+        txn,
+        db_pool: DatabasePool,
+        event_to_room_id: Dict[str, str],
+        event_to_types: Dict[str, Tuple[str, str]],
+        event_to_auth_chain: Dict[str, List[str]],
+    ) -> None:
+        """Calculate the chain cover index for the given events.
+
+        Args:
+            event_to_room_id: Event ID to the room ID of the event
+            event_to_types: Event ID to type and state_key of the event
+            event_to_auth_chain: Event ID to list of auth event IDs of the
+                event (events with no auth events can be excluded).
+        """
+
+        # Map from event ID to chain ID/sequence number.
+        chain_map = {}  # type: Dict[str, Tuple[int, int]]
 
         # Set of event IDs to calculate chain ID/seq numbers for.
-        events_to_calc_chain_id_for = set(state_events)
+        events_to_calc_chain_id_for = set(event_to_room_id)
 
         # We check if there are any events that need to be handled in the rooms
         # we're looking at. These should just be out of band memberships, where
         # we didn't have the auth chain when we first persisted.
-        rows = self.db_pool.simple_select_many_txn(
+        rows = db_pool.simple_select_many_txn(
             txn,
             table="event_auth_chain_to_calculate",
             keyvalues={},
             column="room_id",
-            iterable={e.room_id for e in state_events.values()},
+            iterable=set(event_to_room_id.values()),
             retcols=("event_id", "type", "state_key"),
         )
         for row in rows:
@@ -502,7 +515,7 @@ class PersistEventsStore:
             # (We could pull out the auth events for all rows at once using
             # simple_select_many, but this case happens rarely and almost always
             # with a single row.)
-            auth_events = self.db_pool.simple_select_onecol_txn(
+            auth_events = db_pool.simple_select_onecol_txn(
                 txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
             )
 
@@ -551,9 +564,7 @@ class PersistEventsStore:
 
                     events_to_calc_chain_id_for.add(auth_id)
 
-                    event_to_auth_chain[
-                        auth_id
-                    ] = self.db_pool.simple_select_onecol_txn(
+                    event_to_auth_chain[auth_id] = db_pool.simple_select_onecol_txn(
                         txn,
                         "event_auth",
                         keyvalues={"event_id": auth_id},
@@ -582,16 +593,17 @@ class PersistEventsStore:
                     # the list of events to calculate chain IDs for next time
                     # around. (Otherwise we will have already added it to the
                     # table).
-                    event = state_events.get(event_id)
-                    if event:
-                        self.db_pool.simple_insert_txn(
+                    room_id = event_to_room_id.get(event_id)
+                    if room_id:
+                        e_type, state_key = event_to_types[event_id]
+                        db_pool.simple_insert_txn(
                             txn,
                             table="event_auth_chain_to_calculate",
                             values={
-                                "event_id": event.event_id,
-                                "room_id": event.room_id,
-                                "type": event.type,
-                                "state_key": event.state_key,
+                                "event_id": event_id,
+                                "room_id": room_id,
+                                "type": e_type,
+                                "state_key": state_key,
                             },
                         )
 
@@ -617,7 +629,7 @@ class PersistEventsStore:
             events_to_calc_chain_id_for, event_to_auth_chain
         ):
             existing_chain_id = None
-            for auth_id in event_to_auth_chain[event_id]:
+            for auth_id in event_to_auth_chain.get(event_id, []):
                 if event_to_types.get(event_id) == event_to_types.get(auth_id):
                     existing_chain_id = chain_map[auth_id]
                     break
@@ -629,7 +641,7 @@ class PersistEventsStore:
                 proposed_new_id = existing_chain_id[0]
                 proposed_new_seq = existing_chain_id[1] + 1
                 if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
-                    already_allocated = self.db_pool.simple_select_one_onecol_txn(
+                    already_allocated = db_pool.simple_select_one_onecol_txn(
                         txn,
                         table="event_auth_chains",
                         keyvalues={
@@ -650,14 +662,14 @@ class PersistEventsStore:
                         )
 
             if not new_chain_tuple:
-                new_chain_tuple = (self._event_chain_id_gen.get_next_id_txn(txn), 1)
+                new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
 
             chains_tuples_allocated.add(new_chain_tuple)
 
             chain_map[event_id] = new_chain_tuple
             new_chain_tuples[event_id] = new_chain_tuple
 
-        self.db_pool.simple_insert_many_txn(
+        db_pool.simple_insert_many_txn(
             txn,
             table="event_auth_chains",
             values=[
@@ -666,7 +678,7 @@ class PersistEventsStore:
             ],
         )
 
-        self.db_pool.simple_delete_many_txn(
+        db_pool.simple_delete_many_txn(
             txn,
             table="event_auth_chain_to_calculate",
             keyvalues={},
@@ -699,7 +711,7 @@ class PersistEventsStore:
         # Step 1, fetch all existing links from all the chains we've seen
         # referenced.
         chain_links = _LinkMap()
-        rows = self.db_pool.simple_select_many_txn(
+        rows = db_pool.simple_select_many_txn(
             txn,
             table="event_auth_chain_links",
             column="origin_chain_id",
@@ -730,11 +742,11 @@ class PersistEventsStore:
             # auth events (A, B) to check if B is reachable from A.
             reduction = {
                 a_id
-                for a_id in event_to_auth_chain[event_id]
+                for a_id in event_to_auth_chain.get(event_id, [])
                 if chain_map[a_id][0] != chain_id
             }
             for start_auth_id, end_auth_id in itertools.permutations(
-                event_to_auth_chain[event_id], r=2,
+                event_to_auth_chain.get(event_id, []), r=2,
             ):
                 if chain_links.exists_path_from(
                     chain_map[start_auth_id], chain_map[end_auth_id]
@@ -763,7 +775,7 @@ class PersistEventsStore:
                         (chain_id, sequence_number), (target_id, target_seq)
                     )
 
-        self.db_pool.simple_insert_many_txn(
+        db_pool.simple_insert_many_txn(
             txn,
             table="event_auth_chain_links",
             values=[
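
The allocation loop above gives each event a (chain ID, sequence number) pair: an event extends the chain of an auth event with the same (type, state_key) when the next sequence number in that chain is unallocated, and otherwise starts a fresh chain. A toy, in-memory version of just that rule (no database and no cross-chain links; simplified relative to the real code):

from typing import Dict, List, Set, Tuple


def allocate_chains(
    event_types: Dict[str, Tuple[str, str]],
    auth_events: Dict[str, List[str]],
    topo_order: List[str],
) -> Dict[str, Tuple[int, int]]:
    chain_map = {}  # type: Dict[str, Tuple[int, int]]
    allocated = set()  # type: Set[Tuple[int, int]]
    next_chain = 1
    for event_id in topo_order:
        new_tuple = None
        for auth_id in auth_events.get(event_id, []):
            # Extend the chain of a same-typed auth event if the slot is free.
            if event_types[event_id] == event_types[auth_id]:
                chain_id, seq = chain_map[auth_id]
                if (chain_id, seq + 1) not in allocated:
                    new_tuple = (chain_id, seq + 1)
                break
        if new_tuple is None:
            new_tuple = (next_chain, 1)  # start a fresh chain
            next_chain += 1
        allocated.add(new_tuple)
        chain_map[event_id] = new_tuple
    return chain_map


# Two successive power-levels events share chain 1:
types = {"$a": ("m.room.power_levels", ""), "$b": ("m.room.power_levels", "")}
assert allocate_chains(types, {"$b": ["$a"]}, ["$a", "$b"]) == {
    "$a": (1, 1),
    "$b": (1, 2),
}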
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 7e4b175d08..7128dc1742 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -14,13 +14,14 @@
 # limitations under the License.
 
 import logging
-from typing import List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 from synapse.api.constants import EventContentFields
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import make_event_from_dict
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
+from synapse.storage.databases.main.events import PersistEventsStore
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict
 
@@ -108,6 +109,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             "rejected_events_metadata", self._rejected_events_metadata,
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            "chain_cover", self._chain_cover_index,
+        )
+
     async def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
@@ -706,3 +711,191 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             )
 
         return len(results)
+
+    async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:
+        """A background updates that iterates over all rooms and generates the
+        chain cover index for them.
+        """
+
+        current_room_id = progress.get("current_room_id", "")
+
+        # Have we finished processing the current room?
+        finished = progress.get("finished", True)
+
+        # Where we've processed up to in the room, defaults to the start of the
+        # room.
+        last_depth = progress.get("last_depth", -1)
+        last_stream = progress.get("last_stream", -1)
+
+        # Have we set the `has_auth_chain_index` for the room yet?
+        has_set_room_has_chain_index = progress.get(
+            "has_set_room_has_chain_index", False
+        )
+
+        if finished:
+            # If we've finished with the previous room (or it's our first
+            # iteration) we move on to the next room.
+
+            def _get_next_room(txn: Cursor) -> Optional[str]:
+                sql = """
+                    SELECT room_id FROM rooms
+                    WHERE room_id > ?
+                        AND (
+                            NOT has_auth_chain_index
+                            OR has_auth_chain_index IS NULL
+                        )
+                    ORDER BY room_id
+                    LIMIT 1
+                """
+                txn.execute(sql, (current_room_id,))
+                row = txn.fetchone()
+                if row:
+                    return row[0]
+
+                return None
+
+            current_room_id = await self.db_pool.runInteraction(
+                "_chain_cover_index", _get_next_room
+            )
+            if not current_room_id:
+                await self.db_pool.updates._end_background_update("chain_cover")
+                return 0
+
+            logger.debug("Adding chain cover to %s", current_room_id)
+
+        def _calculate_auth_chain(
+            txn: Cursor, last_depth: int, last_stream: int
+        ) -> Tuple[int, int, int]:
+            # Get the next set of events in the room (that we haven't already
+            # computed chain cover for). We do this in topological order.
+
+            # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
+            # comparison, but that is not supported on older SQLite versions
+            tuple_clause, tuple_args = make_tuple_comparison_clause(
+                self.database_engine,
+                [
+                    ("topological_ordering", last_depth),
+                    ("stream_ordering", last_stream),
+                ],
+            )
+
+            sql = """
+                SELECT
+                    event_id, state_events.type, state_events.state_key,
+                    topological_ordering, stream_ordering
+                FROM events
+                INNER JOIN state_events USING (event_id)
+                LEFT JOIN event_auth_chains USING (event_id)
+                LEFT JOIN event_auth_chain_to_calculate USING (event_id)
+                WHERE events.room_id = ?
+                    AND event_auth_chains.event_id IS NULL
+                    AND event_auth_chain_to_calculate.event_id IS NULL
+                    AND %(tuple_cmp)s
+                ORDER BY topological_ordering, stream_ordering
+                LIMIT ?
+            """ % {
+                "tuple_cmp": tuple_clause,
+            }
+
+            args = [current_room_id]
+            args.extend(tuple_args)
+            args.append(batch_size)
+
+            txn.execute(sql, args)
+            rows = txn.fetchall()
+
+            # Put the results in the necessary format for
+            # `_add_chain_cover_index`
+            event_to_room_id = {row[0]: current_room_id for row in rows}
+            event_to_types = {row[0]: (row[1], row[2]) for row in rows}
+
+            new_last_depth = rows[-1][3] if rows else last_depth  # type: int
+            new_last_stream = rows[-1][4] if rows else last_stream  # type: int
+
+            count = len(rows)
+
+            # We also need to fetch the auth events for them.
+            auth_events = self.db_pool.simple_select_many_txn(
+                txn,
+                table="event_auth",
+                column="event_id",
+                iterable=event_to_room_id,
+                keyvalues={},
+                retcols=("event_id", "auth_id"),
+            )
+
+            event_to_auth_chain = {}  # type: Dict[str, List[str]]
+            for row in auth_events:
+                event_to_auth_chain.setdefault(row["event_id"], []).append(
+                    row["auth_id"]
+                )
+
+            # Calculate and persist the chain cover index for this set of events.
+            #
+            # Annoyingly we need to gut wrench into the persist event store so that
+            # we can reuse the function to calculate the chain cover for rooms.
+            PersistEventsStore._add_chain_cover_index(
+                txn,
+                self.db_pool,
+                event_to_room_id,
+                event_to_types,
+                event_to_auth_chain,
+            )
+
+            return new_last_depth, new_last_stream, count
+
+        last_depth, last_stream, count = await self.db_pool.runInteraction(
+            "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
+        )
+
+        total_rows_processed = count
+
+        if count < batch_size and not has_set_room_has_chain_index:
+            # If we've done all the events in the room we flip the
+            # `has_auth_chain_index` in the DB. Note that it's possible for
+            # further events to be persisted between the above and setting the
+            # flag without having the chain cover calculated for them. This is
+            # fine as a) the code gracefully handles these cases and b) we'll
+            # calculate them below.
+
+            await self.db_pool.simple_update(
+                table="rooms",
+                keyvalues={"room_id": current_room_id},
+                updatevalues={"has_auth_chain_index": True},
+                desc="_chain_cover_index",
+            )
+            has_set_room_has_chain_index = True
+
+            # Handle any events that might have raced with us flipping the
+            # bit above.
+            last_depth, last_stream, count = await self.db_pool.runInteraction(
+                "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
+            )
+
+            total_rows_processed += count
+
+            # Note that at this point it's technically possible that more events
+            # than our `batch_size` have been persisted without their chain
+            # cover, so we need to continue processing this room if the last
+            # count returned was equal to the `batch_size`.
+
+        if count < batch_size:
+            # We've finished calculating the index for this room, move on to the
+            # next room.
+            await self.db_pool.updates._background_update_progress(
+                "chain_cover", {"current_room_id": current_room_id, "finished": True},
+            )
+        else:
+            # We still have outstanding events to calculate the index for.
+            await self.db_pool.updates._background_update_progress(
+                "chain_cover",
+                {
+                    "current_room_id": current_room_id,
+                    "last_depth": last_depth,
+                    "last_stream": last_stream,
+                    "has_auth_chain_index": has_set_room_has_chain_index,
+                    "finished": False,
+                },
+            )
+
+        return total_rows_processed
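
The background update above resumes from a saved (topological_ordering, stream_ordering) position, and make_tuple_comparison_clause emulates the row-value comparison `(a, b) > (?, ?)` for SQLite versions that lack it. A hedged sketch of what such an expansion can look like (an illustrative helper, not the real make_tuple_comparison_clause):

# Illustrative lexicographic tuple comparison: (a, b) > (x, y) becomes
# (a > x) OR (a = x AND b > y).
from typing import List, Tuple


def tuple_gt_clause(keys: List[Tuple[str, int]]) -> Tuple[str, List[int]]:
    clauses = []
    args = []  # type: List[int]
    for i, (col, bound) in enumerate(keys):
        parts = ["%s = ?" % (c,) for c, _ in keys[:i]] + ["%s > ?" % (col,)]
        args.extend([b for _, b in keys[:i]] + [bound])
        clauses.append("(" + " AND ".join(parts) + ")")
    return " OR ".join(clauses), args


clause, args = tuple_gt_clause(
    [("topological_ordering", 3), ("stream_ordering", 7)]
)
assert clause == (
    "(topological_ordering > ?)"
    " OR (topological_ordering = ? AND stream_ordering > ?)"
)
assert args == [3, 3, 7]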
diff --git a/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql
new file mode 100644
index 0000000000..fe3dca71dd
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+  (5906, 'chain_cover', '{}', 'rejected_events_metadata');
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 59207cadd4..cea595ff19 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -464,19 +464,17 @@ class TransactionStore(TransactionWorkerStore):
         txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str]
     ) -> List[str]:
         q = """
-            SELECT destination FROM destinations
-                WHERE destination IN (
-                    SELECT destination FROM destination_rooms
-                        WHERE destination_rooms.stream_ordering >
-                            destinations.last_successful_stream_ordering
-                )
-                AND destination > ?
-                AND (
-                    retry_last_ts IS NULL OR
-                    retry_last_ts + retry_interval < ?
-                )
-                ORDER BY destination
-                LIMIT 25
+            SELECT DISTINCT destination FROM destinations
+            INNER JOIN destination_rooms USING (destination)
+                WHERE
+                    stream_ordering > last_successful_stream_ordering
+                    AND destination > ?
+                    AND (
+                        retry_last_ts IS NULL OR
+                        retry_last_ts + retry_interval < ?
+                    )
+                    ORDER BY destination
+                    LIMIT 25
         """
         txn.execute(
             q,
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 4386b6101e..412df6b8ef 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -15,9 +15,8 @@
 import abc
 import logging
 import threading
-from typing import Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional
 
-from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.engines import (
     BaseDatabaseEngine,
     IncorrectDatabaseSetup,
@@ -25,6 +24,9 @@ from synapse.storage.engines import (
 )
 from synapse.storage.types import Connection, Cursor
 
+if TYPE_CHECKING:
+    from synapse.storage.database import LoggingDatabaseConnection
+
 logger = logging.getLogger(__name__)
 
 
@@ -55,7 +57,7 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
     @abc.abstractmethod
     def check_consistency(
         self,
-        db_conn: LoggingDatabaseConnection,
+        db_conn: "LoggingDatabaseConnection",
         table: str,
         id_column: str,
         positive: bool = True,
@@ -88,7 +90,7 @@ class PostgresSequenceGenerator(SequenceGenerator):
 
     def check_consistency(
         self,
-        db_conn: LoggingDatabaseConnection,
+        db_conn: "LoggingDatabaseConnection",
         table: str,
         id_column: str,
         positive: bool = True,
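
The change above is the standard TYPE_CHECKING idiom for breaking an import cycle: the import only happens for static type checkers, so annotations that use the name must be quoted strings (or the module must enable `from __future__ import annotations`). A minimal sketch:

# Minimal sketch of the TYPE_CHECKING idiom: the import is invisible at
# runtime, so annotations referring to it must be strings.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections import OrderedDict  # stands in for the cyclic import


def first_key(d: "OrderedDict") -> object:
    return next(iter(d))


print(first_key({"a": 1}))  # annotations are not evaluated at runtime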
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index f7b4857a84..6ef2b008a4 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -92,7 +92,7 @@ def sorted_topologically(
         node = heapq.heappop(zero_degree)
         yield node
 
-        for edge in reverse_graph[node]:
+        for edge in reverse_graph.get(node, []):
             if edge in degree_map:
                 degree_map[edge] -= 1
                 if degree_map[edge] == 0:
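
The one-line fix above makes sorted_topologically tolerate nodes that have no entry in the reverse graph. A self-contained sketch of the same algorithm shape (Kahn's algorithm with a heap for deterministic output; simplified relative to the real synapse.util.iterutils implementation):

import heapq
from typing import Dict, Iterable, Iterator, List


def sorted_topologically_sketch(
    nodes: Iterable[int], graph: Dict[int, List[int]]
) -> Iterator[int]:
    # `graph` maps each node to the nodes it depends on; missing entries
    # mean "no dependencies", which is the case the fix above handles.
    node_list = list(nodes)
    degree_map = {n: len(graph.get(n, [])) for n in node_list}
    reverse_graph = {}  # type: Dict[int, List[int]]
    for n in node_list:
        for dep in graph.get(n, []):
            reverse_graph.setdefault(dep, []).append(n)
    zero_degree = [n for n, d in degree_map.items() if d == 0]
    heapq.heapify(zero_degree)
    while zero_degree:
        node = heapq.heappop(zero_degree)
        yield node
        for edge in reverse_graph.get(node, []):  # tolerate missing entries
            if edge in degree_map:
                degree_map[edge] -= 1
                if degree_map[edge] == 0:
                    heapq.heappush(zero_degree, edge)


assert list(sorted_topologically_sketch([1, 2], {})) == [1, 2]
assert list(sorted_topologically_sketch([2, 1], {2: [1]})) == [1, 2]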
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 2abd7a83b5..5d338bea87 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -151,6 +151,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
 
         self.handler = hs.get_oidc_handler()
+        self.provider = self.handler._provider
         sso_handler = hs.get_sso_handler()
         # Mock the render error method.
         self.render_error = Mock(return_value=None)
@@ -162,9 +163,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return hs
 
     def metadata_edit(self, values):
-        return patch.dict(self.handler._provider_metadata, values)
+        return patch.dict(self.provider._provider_metadata, values)
 
     def assertRenderedError(self, error, error_description=None):
+        self.render_error.assert_called_once()
         args = self.render_error.call_args[0]
         self.assertEqual(args[1], error)
         if error_description is not None:
@@ -175,15 +177,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
     def test_config(self):
         """Basic config correctly sets up the callback URL and client auth correctly."""
-        self.assertEqual(self.handler._callback_url, CALLBACK_URL)
-        self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
-        self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
+        self.assertEqual(self.provider._callback_url, CALLBACK_URL)
+        self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
+        self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
 
     @override_config({"oidc_config": {"discover": True}})
     def test_discovery(self):
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
-        metadata = self.get_success(self.handler.load_metadata())
+        metadata = self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
 
         self.assertEqual(metadata.issuer, ISSUER)
@@ -195,47 +197,47 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # subsequent calls should be cached
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_metadata())
+        self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_no_discovery(self):
         """When discovery is disabled, it should not try to load from discovery document."""
-        self.get_success(self.handler.load_metadata())
+        self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_load_jwks(self):
         """JWKS loading is done once (then cached) if used."""
-        jwks = self.get_success(self.handler.load_jwks())
+        jwks = self.get_success(self.provider.load_jwks())
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
         self.assertEqual(jwks, {"keys": []})
 
         # subsequent calls should be cached…
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_jwks())
+        self.get_success(self.provider.load_jwks())
         self.http_client.get_json.assert_not_called()
 
         # …unless forced
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_jwks(force=True))
+        self.get_success(self.provider.load_jwks(force=True))
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
 
         # Throw if the JWKS uri is missing
         with self.metadata_edit({"jwks_uri": None}):
-            self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
+            self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
         # Return empty key set if JWKS are not used
-        self.handler._scopes = []  # not asking the openid scope
+        self.provider._scopes = []  # not asking the openid scope
         self.http_client.get_json.reset_mock()
-        jwks = self.get_success(self.handler.load_jwks(force=True))
+        jwks = self.get_success(self.provider.load_jwks(force=True))
         self.http_client.get_json.assert_not_called()
         self.assertEqual(jwks, {"keys": []})
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_validate_config(self):
         """Provider metadatas are extensively validated."""
-        h = self.handler
+        h = self.provider
 
         # Default test config does not throw
         h._validate_metadata()
@@ -314,13 +316,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """Provider metadata validation can be disabled by config."""
         with self.metadata_edit({"issuer": "http://insecure"}):
             # This should not throw
-            self.handler._validate_metadata()
+            self.provider._validate_metadata()
 
     def test_redirect_request(self):
         """The redirect request has the right arguments & generates a valid session cookie."""
         req = Mock(spec=["addCookie"])
         url = self.get_success(
-            self.handler.handle_redirect_request(req, b"http://client/redirect")
+            self.provider.handle_redirect_request(req, b"http://client/redirect")
         )
         url = urlparse(url)
         auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
@@ -388,7 +390,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # ensure that we are correctly testing the fallback when "get_extra_attributes"
         # is not implemented.
-        mapping_provider = self.handler._user_mapping_provider
+        mapping_provider = self.provider._user_mapping_provider
         with self.assertRaises(AttributeError):
             _ = mapping_provider.get_extra_attributes
 
@@ -403,9 +405,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": username,
         }
         expected_user_id = "@%s:%s" % (username, self.hs.hostname)
-        self.handler._exchange_code = simple_async_mock(return_value=token)
-        self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
-        self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
 
@@ -425,14 +427,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         auth_handler.complete_sso_login.assert_called_once_with(
             expected_user_id, request, client_redirect_url, None,
         )
-        self.handler._exchange_code.assert_called_once_with(code)
-        self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.handler._fetch_userinfo.assert_not_called()
+        self.provider._exchange_code.assert_called_once_with(code)
+        self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
+        self.provider._fetch_userinfo.assert_not_called()
         self.render_error.assert_not_called()
 
         # Handle mapping errors
         with patch.object(
-            self.handler,
+            self.provider,
             "_remote_id_from_userinfo",
             new=Mock(side_effect=MappingException()),
         ):
@@ -440,36 +442,36 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.assertRenderedError("mapping_error")
 
         # Handle ID token errors
-        self.handler._parse_id_token = simple_async_mock(raises=Exception())
+        self.provider._parse_id_token = simple_async_mock(raises=Exception())
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_token")
 
         auth_handler.complete_sso_login.reset_mock()
-        self.handler._exchange_code.reset_mock()
-        self.handler._parse_id_token.reset_mock()
-        self.handler._fetch_userinfo.reset_mock()
+        self.provider._exchange_code.reset_mock()
+        self.provider._parse_id_token.reset_mock()
+        self.provider._fetch_userinfo.reset_mock()
 
         # With userinfo fetching
-        self.handler._scopes = []  # do not ask the "openid" scope
+        self.provider._scopes = []  # do not ask the "openid" scope
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
             expected_user_id, request, client_redirect_url, None,
         )
-        self.handler._exchange_code.assert_called_once_with(code)
-        self.handler._parse_id_token.assert_not_called()
-        self.handler._fetch_userinfo.assert_called_once_with(token)
+        self.provider._exchange_code.assert_called_once_with(code)
+        self.provider._parse_id_token.assert_not_called()
+        self.provider._fetch_userinfo.assert_called_once_with(token)
         self.render_error.assert_not_called()
 
         # Handle userinfo fetching error
-        self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
+        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("fetch_error")
 
         # Handle code exchange failure
         from synapse.handlers.oidc_handler import OidcError
 
-        self.handler._exchange_code = simple_async_mock(
+        self.provider._exchange_code = simple_async_mock(
             raises=OidcError("invalid_request")
         )
         self.get_success(self.handler.handle_oidc_callback(request))
@@ -524,7 +526,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
         )
         code = "code"
-        ret = self.get_success(self.handler._exchange_code(code))
+        ret = self.get_success(self.provider._exchange_code(code))
         kwargs = self.http_client.request.call_args[1]
 
         self.assertEqual(ret, token)
@@ -548,7 +550,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
         from synapse.handlers.oidc_handler import OidcError
 
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "foo")
         self.assertEqual(exc.value.error_description, "bar")
 
@@ -558,7 +560,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 code=500, phrase=b"Internal Server Error", body=b"Not JSON",
             )
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # Internal server error with JSON body
@@ -570,14 +572,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             )
         )
 
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "internal_server_error")
 
         # 4xx error without "error" field
         self.http_client.request = simple_async_mock(
             return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # 2xx error with "error" field
@@ -586,7 +588,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 code=200, phrase=b"OK", body=b'{"error": "some_error"}',
             )
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "some_error")
 
     @override_config(
@@ -612,8 +614,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "foo",
             "phone": "1234567",
         }
-        self.handler._exchange_code = simple_async_mock(return_value=token)
-        self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
 
@@ -979,9 +981,10 @@ async def _make_callback_with_userinfo(
     from synapse.handlers.oidc_handler import OidcSessionData
 
     handler = hs.get_oidc_handler()
-    handler._exchange_code = simple_async_mock(return_value={})
-    handler._parse_id_token = simple_async_mock(return_value=userinfo)
-    handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+    provider = handler._provider
+    provider._exchange_code = simple_async_mock(return_value={})
+    provider._parse_id_token = simple_async_mock(return_value=userinfo)
+    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
 
     state = "state"
     session = handler._token_generator.generate_oidc_session_token(
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 83c377824b..ff67a73749 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -20,7 +20,10 @@ from twisted.trial import unittest
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
 from synapse.storage.databases.main.events import _LinkMap
+from synapse.types import create_requester
 
 from tests.unittest import HomeserverTestCase
 
@@ -470,3 +473,114 @@ class LinkMapTestCase(unittest.TestCase):
         self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
 
         self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
+
+
+class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def test_background_update(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create a room
+        user_id = self.register_user("foo", "pass")
+        token = self.login("foo", "pass")
+        room_id = self.helper.create_room_as(user_id, tok=token)
+        requester = create_requester(user_id)
+
+        store = self.hs.get_datastore()
+
+        # Mark the room as not having a chain cover index
+        self.get_success(
+            store.db_pool.simple_update(
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": False},
+                desc="test",
+            )
+        )
+
+        # Create a fork in the DAG with different events.
+        event_handler = self.hs.get_event_creation_handler()
+        latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
+        event, context = self.get_success(
+            event_handler.create_event(
+                requester,
+                {
+                    "type": "some_state_type",
+                    "state_key": "",
+                    "content": {},
+                    "room_id": room_id,
+                    "sender": user_id,
+                },
+                prev_event_ids=latest_event_ids,
+            )
+        )
+        self.get_success(
+            event_handler.handle_new_client_event(requester, event, context)
+        )
+        state1 = list(self.get_success(context.get_current_state_ids()).values())
+
+        event, context = self.get_success(
+            event_handler.create_event(
+                requester,
+                {
+                    "type": "some_state_type",
+                    "state_key": "",
+                    "content": {},
+                    "room_id": room_id,
+                    "sender": user_id,
+                },
+                prev_event_ids=latest_event_ids,
+            )
+        )
+        self.get_success(
+            event_handler.handle_new_client_event(requester, event, context)
+        )
+        state2 = list(self.get_success(context.get_current_state_ids()).values())
+
+        # Delete the chain cover info.
+
+        def _delete_tables(txn):
+            txn.execute("DELETE FROM event_auth_chains")
+            txn.execute("DELETE FROM event_auth_chain_links")
+
+        self.get_success(store.db_pool.runInteraction("test", _delete_tables))
+
+        # Insert and run the background update.
+        self.get_success(
+            store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        store.db_pool.updates._all_done = False
+
+        while not self.get_success(
+            store.db_pool.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(store.has_auth_chain_index(room_id)))
+
+        # Test that calculating the auth chain difference using the newly
+        # calculated chain cover works.
+        self.get_success(
+            store.db_pool.runInteraction(
+                "test",
+                store._get_auth_chain_difference_using_cover_index_txn,
+                room_id,
+                [state1, state2],
+            )
+        )
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 1184cea5a3..522c8061f9 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -56,6 +56,14 @@ class SortTopologically(TestCase):
         graph = {}  # type: Dict[int, List[int]]
         self.assertEqual(list(sorted_topologically([], graph)), [])
 
+    def test_handle_empty_graph(self):
+        "Test that a graph where a node doesn't have an entry is treated as empty"
+
+        graph = {}  # type: Dict[int, List[int]]
+
+        # For disconnected nodes the output is simply sorted.
+        self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
+
     def test_disconnected(self):
         "Test that a graph with no edges work"