author     Patrick Cloke <patrickc@matrix.org>  2021-06-02 11:38:54 -0400
committer  Patrick Cloke <patrickc@matrix.org>  2021-06-02 11:38:54 -0400
commit     09361655d2fcdac24642efa2b60122dd4f0684be
tree       a8a245f6185f6102b47b6b35fbd33a47d8447571 /synapse
parent     Merge remote-tracking branch 'origin/release-v1.35' into matrix-org-hotfixes
parent     Rewrite the KeyRing (#10035)
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
Diffstat (limited to 'synapse')
-rw-r--r--  synapse/__init__.py                                 2
-rw-r--r--  synapse/api/auth.py                                 8
-rw-r--r--  synapse/app/generic_worker.py                       4
-rw-r--r--  synapse/crypto/keyring.py                         642
-rw-r--r--  synapse/federation/transport/server.py              7
-rw-r--r--  synapse/groups/attestations.py                      4
-rw-r--r--  synapse/handlers/federation.py                     12
-rw-r--r--  synapse/handlers/space_summary.py                  19
-rw-r--r--  synapse/handlers/sync.py                            4
-rw-r--r--  synapse/http/servlet.py                           196
-rw-r--r--  synapse/logging/opentracing.py                     31
-rw-r--r--  synapse/metrics/background_process_metrics.py      10
-rw-r--r--  synapse/rest/client/v1/room.py                      8
-rw-r--r--  synapse/rest/client/v2_alpha/report_event.py       13
-rw-r--r--  synapse/rest/key/v2/remote_key_resource.py          9
-rw-r--r--  synapse/storage/databases/main/cache.py             1
-rw-r--r--  synapse/storage/databases/main/events_worker.py    61
-rw-r--r--  synapse/storage/databases/main/purge_events.py     26
-rw-r--r--  synapse/storage/databases/main/room.py              2
19 files changed, 574 insertions, 485 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 4591246bd1..d9843a1708 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.35.0rc3"
+__version__ = "1.35.0"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 458306eba5..26a3b38918 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -206,11 +206,11 @@ class Auth:
                 requester = create_requester(user_id, app_service=app_service)
 
                 request.requester = user_id
+                if user_id in self._force_tracing_for_users:
+                    opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
                 opentracing.set_tag("authenticated_entity", user_id)
                 opentracing.set_tag("user_id", user_id)
                 opentracing.set_tag("appservice_id", app_service.id)
-                if user_id in self._force_tracing_for_users:
-                    opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
 
                 return requester
 
@@ -259,12 +259,12 @@ class Auth:
             )
 
             request.requester = requester
+            if user_info.token_owner in self._force_tracing_for_users:
+                opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
             opentracing.set_tag("authenticated_entity", user_info.token_owner)
             opentracing.set_tag("user_id", user_info.user_id)
             if device_id:
                 opentracing.set_tag("device_id", device_id)
-            if user_info.token_owner in self._force_tracing_for_users:
-                opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
 
             return requester
         except KeyError:
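
The reordering above raises the sampling priority before the other tags are attached; as the opentracing change later in this diff notes, tags set on an unprioritised span can be dropped. A minimal sketch of the pattern using the plain opentracing package's no-op tracer (the helper function and the hard-coded user set are illustrative, not Synapse code):

    import opentracing
    from opentracing.ext import tags as ot_tags

    FORCE_TRACING_FOR_USERS = {"@alice:example.org"}  # illustrative config value

    def tag_authenticated_request(span, user_id):
        # Raise the sampling priority first, so the descriptive tags that
        # follow survive even on a span that would otherwise be unsampled.
        if user_id in FORCE_TRACING_FOR_USERS:
            span.set_tag(ot_tags.SAMPLING_PRIORITY, 1)
        span.set_tag("authenticated_entity", user_id)
        span.set_tag("user_id", user_id)

    with opentracing.tracer.start_active_span("auth") as scope:
        tag_authenticated_request(scope.span, "@alice:example.org")
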
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 91ad326f19..57c2fc2e88 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -109,7 +109,7 @@ from synapse.storage.databases.main.monthly_active_users import (
     MonthlyActiveUsersWorkerStore,
 )
 from synapse.storage.databases.main.presence import PresenceStore
-from synapse.storage.databases.main.search import SearchWorkerStore
+from synapse.storage.databases.main.search import SearchStore
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.databases.main.transactions import TransactionWorkerStore
 from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
@@ -242,7 +242,7 @@ class GenericWorkerSlavedStore(
     MonthlyActiveUsersWorkerStore,
     MediaRepositoryStore,
     ServerMetricsStore,
-    SearchWorkerStore,
+    SearchStore,
     TransactionWorkerStore,
     BaseSlavedStore,
 ):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 6fc0712978..c840ffca71 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -16,8 +16,7 @@
 import abc
 import logging
 import urllib
-from collections import defaultdict
-from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
 
 import attr
 from signedjson.key import (
@@ -44,17 +43,12 @@ from synapse.api.errors import (
 from synapse.config.key import TrustedKeyServer
 from synapse.events import EventBase
 from synapse.events.utils import prune_event_dict
-from synapse.logging.context import (
-    PreserveLoggingContext,
-    make_deferred_yieldable,
-    preserve_fn,
-    run_in_background,
-)
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage.keys import FetchKeyResult
 from synapse.types import JsonDict
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import yieldable_gather_results
-from synapse.util.metrics import Measure
+from synapse.util.batching_queue import BatchingQueue
 from synapse.util.retryutils import NotRetryingDestination
 
 if TYPE_CHECKING:
@@ -80,32 +74,19 @@ class VerifyJsonRequest:
         minimum_valid_until_ts: time at which we require the signing key to
             be valid. (0 implies we don't care)
 
-        request_name: The name of the request.
-
         key_ids: The set of key_ids that could be used to verify the JSON object
-
-        key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
-            A deferred (server_name, key_id, verify_key) tuple that resolves when
-            a verify key has been fetched. The deferreds' callbacks are run with no
-            logcontext.
-
-            If we are unable to find a key which satisfies the request, the deferred
-            errbacks with an M_UNAUTHORIZED SynapseError.
     """
 
     server_name = attr.ib(type=str)
     get_json_object = attr.ib(type=Callable[[], JsonDict])
     minimum_valid_until_ts = attr.ib(type=int)
-    request_name = attr.ib(type=str)
     key_ids = attr.ib(type=List[str])
-    key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
 
     @staticmethod
     def from_json_object(
         server_name: str,
         json_object: JsonDict,
         minimum_valid_until_ms: int,
-        request_name: str,
     ):
         """Create a VerifyJsonRequest to verify all signatures on a signed JSON
         object for the given server.
@@ -115,7 +96,6 @@ class VerifyJsonRequest:
             server_name,
             lambda: json_object,
             minimum_valid_until_ms,
-            request_name=request_name,
             key_ids=key_ids,
         )
 
@@ -135,16 +115,48 @@ class VerifyJsonRequest:
             # memory than the Event object itself.
             lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
             minimum_valid_until_ms,
-            request_name=event.event_id,
             key_ids=key_ids,
         )
 
+    def to_fetch_key_request(self) -> "_FetchKeyRequest":
+        """Create a key fetch request for all keys needed to satisfy the
+        verification request.
+        """
+        return _FetchKeyRequest(
+            server_name=self.server_name,
+            minimum_valid_until_ts=self.minimum_valid_until_ts,
+            key_ids=self.key_ids,
+        )
+
 
 class KeyLookupError(ValueError):
     pass
 
 
+@attr.s(slots=True)
+class _FetchKeyRequest:
+    """A request for keys for a given server.
+
+    We will continue to try and fetch until we have all the keys listed under
+    `key_ids` (with an appropriate `valid_until_ts` property) or we run out of
+    places to fetch keys from.
+
+    Attributes:
+        server_name: The name of the server that owns the keys.
+        minimum_valid_until_ts: The timestamp which the keys must be valid until.
+        key_ids: The IDs of the keys to attempt to fetch
+    """
+
+    server_name = attr.ib(type=str)
+    minimum_valid_until_ts = attr.ib(type=int)
+    key_ids = attr.ib(type=List[str])
+
+
 class Keyring:
+    """Handles verifying signed JSON objects and fetching the keys needed to do
+    so.
+    """
+
     def __init__(
         self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
     ):
@@ -158,22 +170,22 @@ class Keyring:
             )
         self._key_fetchers = key_fetchers
 
-        # map from server name to Deferred. Has an entry for each server with
-        # an ongoing key download; the Deferred completes once the download
-        # completes.
-        #
-        # These are regular, logcontext-agnostic Deferreds.
-        self.key_downloads = {}  # type: Dict[str, defer.Deferred]
+        self._server_queue = BatchingQueue(
+            "keyring_server",
+            clock=hs.get_clock(),
+            process_batch_callback=self._inner_fetch_key_requests,
+        )  # type: BatchingQueue[_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]]
 
-    def verify_json_for_server(
+    async def verify_json_for_server(
         self,
         server_name: str,
         json_object: JsonDict,
         validity_time: int,
-        request_name: str,
-    ) -> defer.Deferred:
+    ) -> None:
         """Verify that a JSON object has been signed by a given server
 
+        Completes if the object was correctly signed, otherwise raises.
+
         Args:
             server_name: name of the server which must have signed this object
 
@@ -181,52 +193,45 @@ class Keyring:
 
             validity_time: timestamp at which we require the signing key to
                 be valid. (0 implies we don't care)
-
-            request_name: an identifier for this json object (eg, an event id)
-                for logging.
-
-        Returns:
-            Deferred[None]: completes if the the object was correctly signed, otherwise
-                errbacks with an error
         """
         request = VerifyJsonRequest.from_json_object(
             server_name,
             json_object,
             validity_time,
-            request_name,
         )
-        requests = (request,)
-        return make_deferred_yieldable(self._verify_objects(requests)[0])
+        return await self.process_request(request)
 
     def verify_json_objects_for_server(
-        self, server_and_json: Iterable[Tuple[str, dict, int, str]]
+        self, server_and_json: Iterable[Tuple[str, dict, int]]
     ) -> List[defer.Deferred]:
         """Bulk verifies signatures of json objects, bulk fetching keys as
         necessary.
 
         Args:
             server_and_json:
-                Iterable of (server_name, json_object, validity_time, request_name)
+                Iterable of (server_name, json_object, validity_time)
                 tuples.
 
                 validity_time is a timestamp at which the signing key must be
                 valid.
 
-                request_name is an identifier for this json object (eg, an event id)
-                for logging.
-
         Returns:
             List<Deferred[None]>: for each input triplet, a deferred indicating success
                 or failure to verify each json object's signature for the given
                 server_name. The deferreds run their callbacks in the sentinel
                 logcontext.
         """
-        return self._verify_objects(
-            VerifyJsonRequest.from_json_object(
-                server_name, json_object, validity_time, request_name
+        return [
+            run_in_background(
+                self.process_request,
+                VerifyJsonRequest.from_json_object(
+                    server_name,
+                    json_object,
+                    validity_time,
+                ),
             )
-            for server_name, json_object, validity_time, request_name in server_and_json
-        )
+            for server_name, json_object, validity_time in server_and_json
+        ]
 
     def verify_events_for_server(
         self, server_and_events: Iterable[Tuple[str, EventBase, int]]
@@ -252,321 +257,223 @@ class Keyring:
                 server_name. The deferreds run their callbacks in the sentinel
                 logcontext.
         """
-        return self._verify_objects(
-            VerifyJsonRequest.from_event(server_name, event, validity_time)
+        return [
+            run_in_background(
+                self.process_request,
+                VerifyJsonRequest.from_event(
+                    server_name,
+                    event,
+                    validity_time,
+                ),
+            )
             for server_name, event, validity_time in server_and_events
-        )
-
-    def _verify_objects(
-        self, verify_requests: Iterable[VerifyJsonRequest]
-    ) -> List[defer.Deferred]:
-        """Does the work of verify_json_[objects_]for_server
-
-
-        Args:
-            verify_requests: Iterable of verification requests.
+        ]
 
-        Returns:
-            List<Deferred[None]>: for each input item, a deferred indicating success
-                or failure to verify each json object's signature for the given
-                server_name. The deferreds run their callbacks in the sentinel
-                logcontext.
+    async def process_request(self, verify_request: VerifyJsonRequest) -> None:
+        """Processes the `VerifyJsonRequest`. Raises if the object is not signed
+        by the server, the signatures don't match or we failed to fetch the
+        necessary keys.
         """
-        # a list of VerifyJsonRequests which are awaiting a key lookup
-        key_lookups = []
-        handle = preserve_fn(_handle_key_deferred)
-
-        def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
-            """Process an entry in the request list
-
-            Adds a key request to key_lookups, and returns a deferred which
-            will complete or fail (in the sentinel context) when verification completes.
-            """
-            if not verify_request.key_ids:
-                return defer.fail(
-                    SynapseError(
-                        400,
-                        "Not signed by %s" % (verify_request.server_name,),
-                        Codes.UNAUTHORIZED,
-                    )
-                )
 
-            logger.debug(
-                "Verifying %s for %s with key_ids %s, min_validity %i",
-                verify_request.request_name,
-                verify_request.server_name,
-                verify_request.key_ids,
-                verify_request.minimum_valid_until_ts,
+        if not verify_request.key_ids:
+            raise SynapseError(
+                400,
+                f"Not signed by {verify_request.server_name}",
+                Codes.UNAUTHORIZED,
             )
 
-            # add the key request to the queue, but don't start it off yet.
-            key_lookups.append(verify_request)
-
-            # now run _handle_key_deferred, which will wait for the key request
-            # to complete and then do the verification.
-            #
-            # We want _handle_key_request to log to the right context, so we
-            # wrap it with preserve_fn (aka run_in_background)
-            return handle(verify_request)
-
-        results = [process(r) for r in verify_requests]
-
-        if key_lookups:
-            run_in_background(self._start_key_lookups, key_lookups)
-
-        return results
-
-    async def _start_key_lookups(
-        self, verify_requests: List[VerifyJsonRequest]
-    ) -> None:
-        """Sets off the key fetches for each verify request
-
-        Once each fetch completes, verify_request.key_ready will be resolved.
-
-        Args:
-            verify_requests:
-        """
-
-        try:
-            # map from server name to a set of outstanding request ids
-            server_to_request_ids = {}  # type: Dict[str, Set[int]]
-
-            for verify_request in verify_requests:
-                server_name = verify_request.server_name
-                request_id = id(verify_request)
-                server_to_request_ids.setdefault(server_name, set()).add(request_id)
-
-            # Wait for any previous lookups to complete before proceeding.
-            await self.wait_for_previous_lookups(server_to_request_ids.keys())
-
-            # take out a lock on each of the servers by sticking a Deferred in
-            # key_downloads
-            for server_name in server_to_request_ids.keys():
-                self.key_downloads[server_name] = defer.Deferred()
-                logger.debug("Got key lookup lock on %s", server_name)
-
-            # When we've finished fetching all the keys for a given server_name,
-            # drop the lock by resolving the deferred in key_downloads.
-            def drop_server_lock(server_name):
-                d = self.key_downloads.pop(server_name)
-                d.callback(None)
-
-            def lookup_done(res, verify_request):
-                server_name = verify_request.server_name
-                server_requests = server_to_request_ids[server_name]
-                server_requests.remove(id(verify_request))
-
-                # if there are no more requests for this server, we can drop the lock.
-                if not server_requests:
-                    logger.debug("Releasing key lookup lock on %s", server_name)
-                    drop_server_lock(server_name)
-
-                return res
+        # Add the keys we need to verify to the queue for retrieval. We queue
+        # up requests for the same server so we don't end up with many in flight
+        # requests for the same keys.
+        key_request = verify_request.to_fetch_key_request()
+        found_keys_by_server = await self._server_queue.add_to_queue(
+            key_request, key=verify_request.server_name
+        )
 
-            for verify_request in verify_requests:
-                verify_request.key_ready.addBoth(lookup_done, verify_request)
+        # Since we batch up requests the returned set of keys may contain keys
+        # from other servers, so we pull out only the ones we care about.
+        found_keys = found_keys_by_server.get(verify_request.server_name, {})
 
-            # Actually start fetching keys.
-            self._get_server_verify_keys(verify_requests)
-        except Exception:
-            logger.exception("Error starting key lookups")
+        # Verify each signature we got valid keys for, raising if we can't
+        # verify any of them.
+        verified = False
+        for key_id in verify_request.key_ids:
+            key_result = found_keys.get(key_id)
+            if not key_result:
+                continue
 
-    async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
-        """Waits for any previous key lookups for the given servers to finish.
+            if key_result.valid_until_ts < verify_request.minimum_valid_until_ts:
+                continue
 
-        Args:
-            server_names: list of servers which we want to look up
+            verify_key = key_result.verify_key
+            json_object = verify_request.get_json_object()
+            try:
+                verify_signed_json(
+                    json_object,
+                    verify_request.server_name,
+                    verify_key,
+                )
+                verified = True
+            except SignatureVerifyException as e:
+                logger.debug(
+                    "Error verifying signature for %s:%s:%s with key %s: %s",
+                    verify_request.server_name,
+                    verify_key.alg,
+                    verify_key.version,
+                    encode_verify_key_base64(verify_key),
+                    str(e),
+                )
+                raise SynapseError(
+                    401,
+                    "Invalid signature for server %s with key %s:%s: %s"
+                    % (
+                        verify_request.server_name,
+                        verify_key.alg,
+                        verify_key.version,
+                        str(e),
+                    ),
+                    Codes.UNAUTHORIZED,
+                )
 
-        Returns:
-            Resolves once all key lookups for the given servers have
-                completed. Follows the synapse rules of logcontext preservation.
-        """
-        loop_count = 1
-        while True:
-            wait_on = [
-                (server_name, self.key_downloads[server_name])
-                for server_name in server_names
-                if server_name in self.key_downloads
-            ]
-            if not wait_on:
-                break
-            logger.info(
-                "Waiting for existing lookups for %s to complete [loop %i]",
-                [w[0] for w in wait_on],
-                loop_count,
+        if not verified:
+            raise SynapseError(
+                401,
+                f"Failed to find any key to satisfy: {key_request}",
+                Codes.UNAUTHORIZED,
             )
-            with PreserveLoggingContext():
-                await defer.DeferredList((w[1] for w in wait_on))
 
-            loop_count += 1
+    async def _inner_fetch_key_requests(
+        self, requests: List[_FetchKeyRequest]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
+        """Processing function for the queue of `_FetchKeyRequest`."""
+
+        logger.debug("Starting fetch for %s", requests)
+
+        # First we need to deduplicate requests for the same key. We do this by
+        # taking the *maximum* requested `minimum_valid_until_ts` for each pair
+        # of server name/key ID.
+        server_to_key_to_ts = {}  # type: Dict[str, Dict[str, int]]
+        for request in requests:
+            by_server = server_to_key_to_ts.setdefault(request.server_name, {})
+            for key_id in request.key_ids:
+                existing_ts = by_server.get(key_id, 0)
+                by_server[key_id] = max(request.minimum_valid_until_ts, existing_ts)
+
+        deduped_requests = [
+            _FetchKeyRequest(server_name, minimum_valid_ts, [key_id])
+            for server_name, by_server in server_to_key_to_ts.items()
+            for key_id, minimum_valid_ts in by_server.items()
+        ]
+
+        logger.debug("Deduplicated key requests to %s", deduped_requests)
+
+        # For each key we call `_inner_fetch_key_request` which will handle
+        # fetching each key. Note these shouldn't throw if we fail to contact
+        # other servers etc.
+        results_per_request = await yieldable_gather_results(
+            self._inner_fetch_key_request,
+            deduped_requests,
+        )
 
-    def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
-        """Tries to find at least one key for each verify request
+        # We now convert the returned list of results into a map from server
+        # name to key ID to FetchKeyResult, to return.
+        to_return = {}  # type: Dict[str, Dict[str, FetchKeyResult]]
+        for (request, results) in zip(deduped_requests, results_per_request):
+            to_return_by_server = to_return.setdefault(request.server_name, {})
+            for key_id, key_result in results.items():
+                existing = to_return_by_server.get(key_id)
+                if not existing or existing.valid_until_ts < key_result.valid_until_ts:
+                    to_return_by_server[key_id] = key_result
 
-        For each verify_request, verify_request.key_ready is called back with
-        params (server_name, key_id, VerifyKey) if a key is found, or errbacked
-        with a SynapseError if none of the keys are found.
+        return to_return
 
-        Args:
-            verify_requests: list of verify requests
+    async def _inner_fetch_key_request(
+        self, verify_request: _FetchKeyRequest
+    ) -> Dict[str, FetchKeyResult]:
+        """Attempt to fetch the given key by calling each key fetcher one by
+        one.
         """
+        logger.debug("Starting fetch for %s", verify_request)
 
-        remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
+        found_keys: Dict[str, FetchKeyResult] = {}
+        missing_key_ids = set(verify_request.key_ids)
 
-        async def do_iterations():
-            try:
-                with Measure(self.clock, "get_server_verify_keys"):
-                    for f in self._key_fetchers:
-                        if not remaining_requests:
-                            return
-                        await self._attempt_key_fetches_with_fetcher(
-                            f, remaining_requests
-                        )
-
-                    # look for any requests which weren't satisfied
-                    while remaining_requests:
-                        verify_request = remaining_requests.pop()
-                        rq_str = (
-                            "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
-                            % (
-                                verify_request.server_name,
-                                verify_request.key_ids,
-                                verify_request.minimum_valid_until_ts,
-                            )
-                        )
-
-                        # If we run the errback immediately, it may cancel our
-                        # loggingcontext while we are still in it, so instead we
-                        # schedule it for the next time round the reactor.
-                        #
-                        # (this also ensures that we don't get a stack overflow if we
-                        # has a massive queue of lookups waiting for this server).
-                        self.clock.call_later(
-                            0,
-                            verify_request.key_ready.errback,
-                            SynapseError(
-                                401,
-                                "Failed to find any key to satisfy %s" % (rq_str,),
-                                Codes.UNAUTHORIZED,
-                            ),
-                        )
-            except Exception as err:
-                # we don't really expect to get here, because any errors should already
-                # have been caught and logged. But if we do, let's log the error and make
-                # sure that all of the deferreds are resolved.
-                logger.error("Unexpected error in _get_server_verify_keys: %s", err)
-                with PreserveLoggingContext():
-                    for verify_request in remaining_requests:
-                        if not verify_request.key_ready.called:
-                            verify_request.key_ready.errback(err)
-
-        run_in_background(do_iterations)
-
-    async def _attempt_key_fetches_with_fetcher(
-        self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
-    ):
-        """Use a key fetcher to attempt to satisfy some key requests
+        for fetcher in self._key_fetchers:
+            if not missing_key_ids:
+                break
 
-        Args:
-            fetcher: fetcher to use to fetch the keys
-            remaining_requests: outstanding key requests.
-                Any successfully-completed requests will be removed from the list.
-        """
-        # The keys to fetch.
-        # server_name -> key_id -> min_valid_ts
-        missing_keys = defaultdict(dict)  # type: Dict[str, Dict[str, int]]
-
-        for verify_request in remaining_requests:
-            # any completed requests should already have been removed
-            assert not verify_request.key_ready.called
-            keys_for_server = missing_keys[verify_request.server_name]
-
-            for key_id in verify_request.key_ids:
-                # If we have several requests for the same key, then we only need to
-                # request that key once, but we should do so with the greatest
-                # min_valid_until_ts of the requests, so that we can satisfy all of
-                # the requests.
-                keys_for_server[key_id] = max(
-                    keys_for_server.get(key_id, -1),
-                    verify_request.minimum_valid_until_ts,
-                )
+            logger.debug("Getting keys from %s for %s", fetcher, verify_request)
+            keys = await fetcher.get_keys(
+                verify_request.server_name,
+                list(missing_key_ids),
+                verify_request.minimum_valid_until_ts,
+            )
 
-        results = await fetcher.get_keys(missing_keys)
+            for key_id, key in keys.items():
+                if not key:
+                    continue
 
-        completed = []
-        for verify_request in remaining_requests:
-            server_name = verify_request.server_name
+                # If we already have a result for the given key ID we keep the
+                # one with the highest `valid_until_ts`.
+                existing_key = found_keys.get(key_id)
+                if existing_key:
+                    if key.valid_until_ts <= existing_key.valid_until_ts:
+                        continue
 
-            # see if any of the keys we got this time are sufficient to
-            # complete this VerifyJsonRequest.
-            result_keys = results.get(server_name, {})
-            for key_id in verify_request.key_ids:
-                fetch_key_result = result_keys.get(key_id)
-                if not fetch_key_result:
-                    # we didn't get a result for this key
-                    continue
+                # We always store the returned key even if it doesn't meet the
+                # `minimum_valid_until_ts` requirement, as some verification
+                # requests may still be able to be satisfied by it.
+                #
+                # We still keep looking for the key from other fetchers in that
+                # case though.
+                found_keys[key_id] = key
 
-                if (
-                    fetch_key_result.valid_until_ts
-                    < verify_request.minimum_valid_until_ts
-                ):
-                    # key was not valid at this point
+                if key.valid_until_ts < verify_request.minimum_valid_until_ts:
                     continue
 
-                # we have a valid key for this request. If we run the callback
-                # immediately, it may cancel our loggingcontext while we are still in
-                # it, so instead we schedule it for the next time round the reactor.
-                #
-                # (this also ensures that we don't get a stack overflow if we had
-                # a massive queue of lookups waiting for this server).
-                logger.debug(
-                    "Found key %s:%s for %s",
-                    server_name,
-                    key_id,
-                    verify_request.request_name,
-                )
-                self.clock.call_later(
-                    0,
-                    verify_request.key_ready.callback,
-                    (server_name, key_id, fetch_key_result.verify_key),
-                )
-                completed.append(verify_request)
-                break
+                missing_key_ids.discard(key_id)
 
-        remaining_requests.difference_update(completed)
+        return found_keys
 
 
 class KeyFetcher(metaclass=abc.ABCMeta):
-    @abc.abstractmethod
+    def __init__(self, hs: "HomeServer"):
+        self._queue = BatchingQueue(
+            self.__class__.__name__, hs.get_clock(), self._fetch_keys
+        )
+
     async def get_keys(
-        self, keys_to_fetch: Dict[str, Dict[str, int]]
-    ) -> Dict[str, Dict[str, FetchKeyResult]]:
-        """
-        Args:
-            keys_to_fetch:
-                the keys to be fetched. server_name -> key_id -> min_valid_ts
+        self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+    ) -> Dict[str, FetchKeyResult]:
+        results = await self._queue.add_to_queue(
+            _FetchKeyRequest(
+                server_name=server_name,
+                key_ids=key_ids,
+                minimum_valid_until_ts=minimum_valid_until_ts,
+            )
+        )
+        return results.get(server_name, {})
 
-        Returns:
-            Map from server_name -> key_id -> FetchKeyResult
-        """
-        raise NotImplementedError
+    @abc.abstractmethod
+    async def _fetch_keys(
+        self, keys_to_fetch: List[_FetchKeyRequest]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
+        pass
 
 
 class StoreKeyFetcher(KeyFetcher):
     """KeyFetcher impl which fetches keys from our data store"""
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        super().__init__(hs)
 
-    async def get_keys(
-        self, keys_to_fetch: Dict[str, Dict[str, int]]
-    ) -> Dict[str, Dict[str, FetchKeyResult]]:
-        """see KeyFetcher.get_keys"""
+        self.store = hs.get_datastore()
 
+    async def _fetch_keys(self, keys_to_fetch: List[_FetchKeyRequest]):
         key_ids_to_fetch = (
-            (server_name, key_id)
-            for server_name, keys_for_server in keys_to_fetch.items()
-            for key_id in keys_for_server.keys()
+            (queue_value.server_name, key_id)
+            for queue_value in keys_to_fetch
+            for key_id in queue_value.key_ids
         )
 
         res = await self.store.get_server_verify_keys(key_ids_to_fetch)
@@ -578,6 +485,8 @@ class StoreKeyFetcher(KeyFetcher):
 
 class BaseV2KeyFetcher(KeyFetcher):
     def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
         self.store = hs.get_datastore()
         self.config = hs.config
 
@@ -685,10 +594,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         self.client = hs.get_federation_http_client()
         self.key_servers = self.config.key_servers
 
-    async def get_keys(
-        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    async def _fetch_keys(
+        self, keys_to_fetch: List[_FetchKeyRequest]
     ) -> Dict[str, Dict[str, FetchKeyResult]]:
-        """see KeyFetcher.get_keys"""
+        """see KeyFetcher._fetch_keys"""
 
         async def get_key(key_server: TrustedKeyServer) -> Dict:
             try:
@@ -724,12 +633,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         return union_of_keys
 
     async def get_server_verify_key_v2_indirect(
-        self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
+        self, keys_to_fetch: List[_FetchKeyRequest], key_server: TrustedKeyServer
     ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """
         Args:
             keys_to_fetch:
-                the keys to be fetched. server_name -> key_id -> min_valid_ts
+                the keys to be fetched.
 
             key_server: notary server to query for the keys
 
@@ -743,7 +652,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         perspective_name = key_server.server_name
         logger.info(
             "Requesting keys %s from notary server %s",
-            keys_to_fetch.items(),
+            keys_to_fetch,
             perspective_name,
         )
 
@@ -753,11 +662,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
                 path="/_matrix/key/v2/query",
                 data={
                     "server_keys": {
-                        server_name: {
-                            key_id: {"minimum_valid_until_ts": min_valid_ts}
-                            for key_id, min_valid_ts in server_keys.items()
+                        queue_value.server_name: {
+                            key_id: {
+                                "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
+                            }
+                            for key_id in queue_value.key_ids
                         }
-                        for server_name, server_keys in keys_to_fetch.items()
+                        for queue_value in keys_to_fetch
                     }
                 },
             )
@@ -858,7 +769,20 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         self.client = hs.get_federation_http_client()
 
     async def get_keys(
-        self, keys_to_fetch: Dict[str, Dict[str, int]]
+        self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+    ) -> Dict[str, FetchKeyResult]:
+        results = await self._queue.add_to_queue(
+            _FetchKeyRequest(
+                server_name=server_name,
+                key_ids=key_ids,
+                minimum_valid_until_ts=minimum_valid_until_ts,
+            ),
+            key=server_name,
+        )
+        return results.get(server_name, {})
+
+    async def _fetch_keys(
+        self, keys_to_fetch: List[_FetchKeyRequest]
     ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """
         Args:
@@ -871,8 +795,10 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
 
         results = {}
 
-        async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
-            server_name, key_ids = key_to_fetch_item
+        async def get_key(key_to_fetch_item: _FetchKeyRequest) -> None:
+            server_name = key_to_fetch_item.server_name
+            key_ids = key_to_fetch_item.key_ids
+
             try:
                 keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
                 results[server_name] = keys
@@ -883,7 +809,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             except Exception:
                 logger.exception("Error getting keys %s from %s", key_ids, server_name)
 
-        await yieldable_gather_results(get_key, keys_to_fetch.items())
+        await yieldable_gather_results(get_key, keys_to_fetch)
         return results
 
     async def get_server_verify_key_v2_direct(
@@ -955,37 +881,3 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             keys.update(response_keys)
 
         return keys
-
-
-async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
-    """Waits for the key to become available, and then performs a verification
-
-    Args:
-        verify_request:
-
-    Raises:
-        SynapseError if there was a problem performing the verification
-    """
-    server_name = verify_request.server_name
-    with PreserveLoggingContext():
-        _, key_id, verify_key = await verify_request.key_ready
-
-    json_object = verify_request.get_json_object()
-
-    try:
-        verify_signed_json(json_object, server_name, verify_key)
-    except SignatureVerifyException as e:
-        logger.debug(
-            "Error verifying signature for %s:%s:%s with key %s: %s",
-            server_name,
-            verify_key.alg,
-            verify_key.version,
-            encode_verify_key_base64(verify_key),
-            str(e),
-        )
-        raise SynapseError(
-            401,
-            "Invalid signature for server %s with key %s:%s: %s"
-            % (server_name, verify_key.alg, verify_key.version, str(e)),
-            Codes.UNAUTHORIZED,
-        )
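
The new `_inner_fetch_key_requests` above collapses the queued requests before any network traffic happens: one request per (server, key ID), carrying the largest `minimum_valid_until_ts` seen, so a single fetch can satisfy every waiting caller. A self-contained sketch of that deduplication step (a dataclass stands in for the attrs-based `_FetchKeyRequest`):

    from dataclasses import dataclass
    from typing import Dict, List

    @dataclass
    class FetchKeyRequest:
        server_name: str
        minimum_valid_until_ts: int
        key_ids: List[str]

    def dedupe(requests: List[FetchKeyRequest]) -> List[FetchKeyRequest]:
        # server_name -> key_id -> maximum minimum_valid_until_ts requested
        server_to_key_to_ts: Dict[str, Dict[str, int]] = {}
        for request in requests:
            by_server = server_to_key_to_ts.setdefault(request.server_name, {})
            for key_id in request.key_ids:
                existing_ts = by_server.get(key_id, 0)
                by_server[key_id] = max(request.minimum_valid_until_ts, existing_ts)

        return [
            FetchKeyRequest(server_name, valid_ts, [key_id])
            for server_name, by_server in server_to_key_to_ts.items()
            for key_id, valid_ts in by_server.items()
        ]

    requests = [
        FetchKeyRequest("example.org", 100, ["ed25519:a", "ed25519:b"]),
        FetchKeyRequest("example.org", 500, ["ed25519:a"]),
    ]
    # ed25519:a is fetched once with ts 500; ed25519:b once with ts 100.
    print(dedupe(requests))
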
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 59e0a434dc..5756fcb551 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -37,6 +37,7 @@ from synapse.http.servlet import (
 )
 from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
+    SynapseTags,
     start_active_span,
     start_active_span_from_request,
     tags,
@@ -151,7 +152,9 @@ class Authenticator:
             )
 
         await self.keyring.verify_json_for_server(
-            origin, json_request, now, "Incoming request"
+            origin,
+            json_request,
+            now,
         )
 
         logger.debug("Request from %s", origin)
@@ -314,7 +317,7 @@ class BaseFederationServlet:
                 raise
 
             request_tags = {
-                "request_id": request.get_request_id(),
+                SynapseTags.REQUEST_ID: request.get_request_id(),
                 tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
                 tags.HTTP_METHOD: request.get_method(),
                 tags.HTTP_URL: request.get_redacted_uri(),
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index d2fc8be5f5..ff8372c4e9 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -108,7 +108,9 @@ class GroupAttestationSigning:
 
         assert server_name is not None
         await self.keyring.verify_json_for_server(
-            server_name, attestation, now, "Group attestation"
+            server_name,
+            attestation,
+            now,
         )
 
     def create_attestation(self, group_id: str, user_id: str) -> JsonDict:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bf11315251..49ed7cabcc 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -577,7 +577,9 @@ class FederationHandler(BaseHandler):
 
         # Fetch the state events from the DB, and check we have the auth events.
         event_map = await self.store.get_events(state_event_ids, allow_rejected=True)
-        auth_events_in_store = await self.store.have_seen_events(auth_event_ids)
+        auth_events_in_store = await self.store.have_seen_events(
+            room_id, auth_event_ids
+        )
 
         # Check for missing events. We handle state and auth event seperately,
         # as we want to pull the state from the DB, but we don't for the auth
@@ -610,7 +612,7 @@ class FederationHandler(BaseHandler):
 
             if missing_auth_events:
                 auth_events_in_store = await self.store.have_seen_events(
-                    missing_auth_events
+                    room_id, missing_auth_events
                 )
                 missing_auth_events.difference_update(auth_events_in_store)
 
@@ -710,7 +712,7 @@ class FederationHandler(BaseHandler):
 
         missing_auth_events = set(auth_event_ids) - fetched_events.keys()
         missing_auth_events.difference_update(
-            await self.store.have_seen_events(missing_auth_events)
+            await self.store.have_seen_events(room_id, missing_auth_events)
         )
         logger.debug("We are also missing %i auth events", len(missing_auth_events))
 
@@ -2475,7 +2477,7 @@ class FederationHandler(BaseHandler):
         #
         # we start by checking if they are in the store, and then try calling /event_auth/.
         if missing_auth:
-            have_events = await self.store.have_seen_events(missing_auth)
+            have_events = await self.store.have_seen_events(event.room_id, missing_auth)
             logger.debug("Events %s are in the store", have_events)
             missing_auth.difference_update(have_events)
 
@@ -2494,7 +2496,7 @@ class FederationHandler(BaseHandler):
                     return context
 
                 seen_remotes = await self.store.have_seen_events(
-                    [e.event_id for e in remote_auth_chain]
+                    event.room_id, [e.event_id for e in remote_auth_chain]
                 )
 
                 for e in remote_auth_chain:
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index abd9ddecca..046dba6fd8 100644
--- a/synapse/handlers/space_summary.py
+++ b/synapse/handlers/space_summary.py
@@ -26,7 +26,6 @@ from synapse.api.constants import (
     HistoryVisibility,
     Membership,
 )
-from synapse.api.errors import AuthError
 from synapse.events import EventBase
 from synapse.events.utils import format_event_for_client_v2
 from synapse.types import JsonDict
@@ -456,16 +455,16 @@ class SpaceSummaryHandler:
                     return True
 
             # Otherwise, check if they should be allowed access via membership in a space.
-            try:
-                await self._event_auth_handler.check_restricted_join_rules(
-                    state_ids, room_version, requester, member_event
+            if self._event_auth_handler.has_restricted_join_rules(
+                state_ids, room_version
+            ):
+                allowed_spaces = (
+                    await self._event_auth_handler.get_spaces_that_allow_join(state_ids)
                 )
-            except AuthError:
-                # The user doesn't have access due to spaces, but might have access
-                # another way. Keep trying.
-                pass
-            else:
-                return True
+                if await self._event_auth_handler.is_user_in_rooms(
+                    allowed_spaces, requester
+                ):
+                    return True
 
         # If this is a request over federation, check if the host is in the room or
         # is in one of the spaces specified via the join rules.
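
The hunk above replaces the try/except around `check_restricted_join_rules` with explicit checks. A rough sketch of the resulting flow, with a stub standing in for Synapse's event auth handler (the stub's return values are invented for illustration):

    import asyncio
    from typing import Iterable, List

    class StubEventAuthHandler:
        """Stand-in for the event auth handler used above; not Synapse code."""

        def has_restricted_join_rules(self, state_ids, room_version) -> bool:
            return True

        async def get_spaces_that_allow_join(self, state_ids) -> List[str]:
            return ["!space:example.org"]

        async def is_user_in_rooms(self, room_ids: Iterable[str], user_id: str) -> bool:
            return user_id == "@alice:example.org"

    async def allowed_via_spaces(handler, state_ids, room_version, requester) -> bool:
        # Only rooms with restricted join rules can grant access via a space,
        # and then only if the requester is in one of the allowed spaces.
        if not handler.has_restricted_join_rules(state_ids, room_version):
            return False
        allowed_spaces = await handler.get_spaces_that_allow_join(state_ids)
        return await handler.is_user_in_rooms(allowed_spaces, requester)

    print(asyncio.run(allowed_via_spaces(StubEventAuthHandler(), {}, None, "@alice:example.org")))
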
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 02ae36d2de..e607527ad1 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -464,7 +464,7 @@ class SyncHandler:
                 # ensure that we always include current state in the timeline
                 current_state_ids = frozenset()  # type: FrozenSet[str]
                 if any(e.is_state() for e in recents):
-                    current_state_ids_map = await self.state.get_current_state_ids(
+                    current_state_ids_map = await self.store.get_current_state_ids(
                         room_id
                     )
                     current_state_ids = frozenset(current_state_ids_map.values())
@@ -524,7 +524,7 @@ class SyncHandler:
                 # ensure that we always include current state in the timeline
                 current_state_ids = frozenset()
                 if any(e.is_state() for e in loaded_recents):
-                    current_state_ids_map = await self.state.get_current_state_ids(
+                    current_state_ids_map = await self.store.get_current_state_ids(
                         room_id
                     )
                     current_state_ids = frozenset(current_state_ids_map.values())
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 31897546a9..3f4f2411fc 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -15,6 +15,9 @@
 """ This module contains base REST classes for constructing REST servlets. """
 
 import logging
+from typing import Iterable, List, Optional, Union, overload
+
+from typing_extensions import Literal
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.util import json_decoder
@@ -107,12 +110,11 @@ def parse_boolean_from_args(args, name, default=None, required=False):
 
 def parse_string(
     request,
-    name,
-    default=None,
-    required=False,
-    allowed_values=None,
-    param_type="string",
-    encoding="ascii",
+    name: Union[bytes, str],
+    default: Optional[str] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: Optional[str] = "ascii",
 ):
     """
     Parse a string parameter from the request query string.
@@ -122,18 +124,17 @@ def parse_string(
 
     Args:
         request: the twisted HTTP request.
-        name (bytes|unicode): the name of the query parameter.
-        default (bytes|unicode|None): value to use if the parameter is absent,
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
             defaults to None. Must be bytes if encoding is None.
-        required (bool): whether to raise a 400 SynapseError if the
+        required: whether to raise a 400 SynapseError if the
             parameter is absent, defaults to False.
-        allowed_values (list[bytes|unicode]): List of allowed values for the
+        allowed_values: List of allowed values for the
             string, or None if any value is allowed, defaults to None. Must be
             the same type as name, if given.
-        encoding (str|None): The encoding to decode the string content with.
-
+        encoding: The encoding to decode the string content with.
     Returns:
-        bytes/unicode|None: A string value or the default. Unicode if encoding
+        A string value or the default. Unicode if encoding
         was given, bytes otherwise.
 
     Raises:
@@ -142,45 +143,105 @@ def parse_string(
             is not one of those allowed values.
     """
     return parse_string_from_args(
-        request.args, name, default, required, allowed_values, param_type, encoding
+        request.args, name, default, required, allowed_values, encoding
     )
 
 
-def parse_string_from_args(
-    args,
-    name,
-    default=None,
-    required=False,
-    allowed_values=None,
-    param_type="string",
-    encoding="ascii",
-):
+def _parse_string_value(
+    value: Union[str, bytes],
+    allowed_values: Optional[Iterable[str]],
+    name: str,
+    encoding: Optional[str],
+) -> Union[str, bytes]:
+    if encoding:
+        try:
+            value = value.decode(encoding)
+        except ValueError:
+            raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
+
+    if allowed_values is not None and value not in allowed_values:
+        message = "Query parameter %r must be one of [%s]" % (
+            name,
+            ", ".join(repr(v) for v in allowed_values),
+        )
+        raise SynapseError(400, message)
+    else:
+        return value
+
+
+@overload
+def parse_strings_from_args(
+    args: List[str],
+    name: Union[bytes, str],
+    default: Optional[List[str]] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: Literal[None] = None,
+) -> Optional[List[bytes]]:
+    ...
+
+
+@overload
+def parse_strings_from_args(
+    args: List[str],
+    name: Union[bytes, str],
+    default: Optional[List[str]] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> Optional[List[str]]:
+    ...
+
+
+def parse_strings_from_args(
+    args: List[str],
+    name: Union[bytes, str],
+    default: Optional[List[str]] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: Optional[str] = "ascii",
+) -> Optional[List[Union[bytes, str]]]:
+    """
+    Parse a string parameter from the request query string list.
+
+    If encoding is not None, the content of the query param will be
+    decoded to Unicode using the encoding, otherwise it will be returned as bytes.
+
+    Args:
+        args: the twisted HTTP request.args list.
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
+            defaults to None. Must be bytes if encoding is None.
+        required: whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+        allowed_values: List of allowed values for the
+            string, or None if any value is allowed, defaults to None. Must be
+            the same type as name, if given.
+        encoding: The encoding to decode the string content with.
+
+    Returns:
+        A list of string values, or the default. Unicode if encoding
+        was given, bytes otherwise.
+
+    Raises:
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present, must be one of a list of allowed values and
+            is not one of those allowed values.
+    """
 
     if not isinstance(name, bytes):
         name = name.encode("ascii")
 
     if name in args:
-        value = args[name][0]
-
-        if encoding:
-            try:
-                value = value.decode(encoding)
-            except ValueError:
-                raise SynapseError(
-                    400, "Query parameter %r must be %s" % (name, encoding)
-                )
-
-        if allowed_values is not None and value not in allowed_values:
-            message = "Query parameter %r must be one of [%s]" % (
-                name,
-                ", ".join(repr(v) for v in allowed_values),
-            )
-            raise SynapseError(400, message)
-        else:
-            return value
+        values = args[name]
+
+        return [
+            _parse_string_value(value, allowed_values, name=name, encoding=encoding)
+            for value in values
+        ]
     else:
         if required:
-            message = "Missing %s query parameter %r" % (param_type, name)
+            message = "Missing string query parameter %r" % (name)
             raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
         else:
 
@@ -190,6 +251,55 @@ def parse_string_from_args(
             return default
 
 
+def parse_string_from_args(
+    args: List[str],
+    name: Union[bytes, str],
+    default: Optional[str] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: Optional[str] = "ascii",
+) -> Optional[Union[bytes, str]]:
+    """
+    Parse the string parameter from the request query string list
+    and return the first result.
+
+    If encoding is not None, the content of the query param will be
+    decoded to Unicode using the encoding, otherwise it will be returned as bytes.
+
+    Args:
+        args: the twisted HTTP request.args list.
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
+            defaults to None. Must be bytes if encoding is None.
+        required: whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+        allowed_values: List of allowed values for the
+            string, or None if any value is allowed, defaults to None. Must be
+            the same type as name, if given.
+        encoding: The encoding to decode the string content with.
+
+    Returns:
+        A string value or the default. Unicode if encoding
+        was given, bytes otherwise.
+
+    Raises:
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present, must be one of a list of allowed values and
+            is not one of those allowed values.
+    """
+
+    strings = parse_strings_from_args(
+        args,
+        name,
+        default=[default],
+        required=required,
+        allowed_values=allowed_values,
+        encoding=encoding,
+    )
+
+    return strings[0]
+
+
 def parse_json_value_from_request(request, allow_empty_body=False):
     """Parse a JSON value from the body of a twisted HTTP request.
 
@@ -215,7 +325,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
     try:
         content = json_decoder.decode(content_bytes.decode("utf-8"))
     except Exception as e:
-        logger.warning("Unable to parse JSON: %s", e)
+        logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes)
         raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
 
     return content
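
A short usage sketch of the new `parse_strings_from_args` helper added above and the thin `parse_string_from_args` wrapper, assuming this version of Synapse is importable; the args mapping mimics Twisted's `request.args` (bytes keys mapped to lists of bytes values):

    from synapse.http.servlet import parse_string_from_args, parse_strings_from_args

    # request.args as Twisted builds it: bytes keys mapped to lists of bytes values.
    args = {b"room_id": [b"!a:example.org", b"!b:example.org"]}

    # Every value supplied for a repeated query parameter, decoded as ascii:
    print(parse_strings_from_args(args, "room_id"))   # ['!a:example.org', '!b:example.org']

    # Only the first value, which is what parse_string_from_args now returns:
    print(parse_string_from_args(args, "room_id"))    # '!a:example.org'
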
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index fba2fa3904..f64845b80c 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -265,6 +265,12 @@ class SynapseTags:
     # Whether the sync response has new data to be returned to the client.
     SYNC_RESULT = "sync.new_data"
 
+    # incoming HTTP request ID  (as written in the logs)
+    REQUEST_ID = "request_id"
+
+    # HTTP request tag (used to distinguish full vs incremental syncs, etc)
+    REQUEST_TAG = "request_tag"
+
 
 # Block everything by default
 # A regex which matches the server_names to expose traces for.
@@ -588,7 +594,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
 
     span = opentracing.tracer.active_span
     carrier = {}  # type: Dict[str, str]
-    opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
+    opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
 
     for key, value in carrier.items():
         headers.addRawHeaders(key, value)
@@ -625,7 +631,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
     span = opentracing.tracer.active_span
 
     carrier = {}  # type: Dict[str, str]
-    opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
+    opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
 
     for key, value in carrier.items():
         headers[key.encode()] = [value.encode()]
@@ -659,7 +665,7 @@ def inject_active_span_text_map(carrier, destination, check_destination=True):
         return
 
     opentracing.tracer.inject(
-        opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+        opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
     )
 
 
@@ -681,7 +687,7 @@ def get_active_span_text_map(destination=None):
 
     carrier = {}  # type: Dict[str, str]
     opentracing.tracer.inject(
-        opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+        opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
     )
 
     return carrier
@@ -696,7 +702,7 @@ def active_span_context_as_string():
     carrier = {}  # type: Dict[str, str]
     if opentracing:
         opentracing.tracer.inject(
-            opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+            opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
         )
     return json_encoder.encode(carrier)
 
@@ -824,7 +830,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
         return
 
     request_tags = {
-        "request_id": request.get_request_id(),
+        SynapseTags.REQUEST_ID: request.get_request_id(),
         tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
         tags.HTTP_METHOD: request.get_method(),
         tags.HTTP_URL: request.get_redacted_uri(),
@@ -833,9 +839,9 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
 
     request_name = request.request_metrics.name
     if extract_context:
-        scope = start_active_span_from_request(request, request_name, tags=request_tags)
+        scope = start_active_span_from_request(request, request_name)
     else:
-        scope = start_active_span(request_name, tags=request_tags)
+        scope = start_active_span(request_name)
 
     with scope:
         try:
@@ -845,4 +851,11 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
             # with JsonResource).
             scope.span.set_operation_name(request.request_metrics.name)
 
-            scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)
+            # set the tags *after* the servlet completes, in case it decided to
+            # prioritise the span (tags will get dropped on unprioritised spans)
+            request_tags[
+                SynapseTags.REQUEST_TAG
+            ] = request.request_metrics.start_context.tag
+
+            for k, v in request_tags.items():
+                scope.span.set_tag(k, v)
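
The repeated `span` to `span.context` change above reflects the OpenTracing API contract: `Tracer.inject()` takes a `SpanContext`, not a `Span`. A minimal sketch of the corrected pattern (assuming a real tracer has been configured rather than the no-op default):

    import opentracing

    carrier = {}
    span = opentracing.tracer.active_span
    if span is not None:
        # inject() wants the SpanContext; passing the Span itself only happened
        # to work with some tracer implementations.
        opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
    # carrier now holds the propagation headers to attach to an outgoing request.
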
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 714caf84c3..0d6d643d35 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -22,7 +22,11 @@ from prometheus_client.core import REGISTRY, Counter, Gauge
 from twisted.internet import defer
 
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
-from synapse.logging.opentracing import noop_context_manager, start_active_span
+from synapse.logging.opentracing import (
+    SynapseTags,
+    noop_context_manager,
+    start_active_span,
+)
 from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
@@ -202,7 +206,9 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
             try:
                 ctx = noop_context_manager()
                 if bg_start_span:
-                    ctx = start_active_span(desc, tags={"request_id": str(context)})
+                    ctx = start_active_span(
+                        desc, tags={SynapseTags.REQUEST_ID: str(context)}
+                    )
                 with ctx:
                     return await maybe_awaitable(func(*args, **kwargs))
             except Exception:
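
For reference, a hypothetical caller (names illustrative) is unchanged by this edit; the only difference is that the span started for it is now tagged via SynapseTags.REQUEST_ID rather than the bare "request_id" literal:

    from synapse.metrics.background_process_metrics import run_as_background_process

    async def _cleanup_expired_things():
        # hypothetical background job body
        ...

    # Starts a logging context and, if tracing is enabled, an active span
    # tagged with SynapseTags.REQUEST_ID for the calling context.
    run_as_background_process("cleanup_expired_things", _cleanup_expired_things)
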
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index d6d55893af..70286b0ff7 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -1061,15 +1061,15 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
     RoomTypingRestServlet(hs).register(http_server)
     RoomEventContextServlet(hs).register(http_server)
     RoomSpaceSummaryRestServlet(hs).register(http_server)
+    RoomEventServlet(hs).register(http_server)
+    JoinedRoomsRestServlet(hs).register(http_server)
+    RoomAliasListServlet(hs).register(http_server)
+    SearchRestServlet(hs).register(http_server)
 
     # Some servlets only get registered for the main process.
     if not is_worker:
         RoomCreateRestServlet(hs).register(http_server)
         RoomForgetRestServlet(hs).register(http_server)
-        SearchRestServlet(hs).register(http_server)
-        JoinedRoomsRestServlet(hs).register(http_server)
-        RoomEventServlet(hs).register(http_server)
-        RoomAliasListServlet(hs).register(http_server)
 
 
 def register_deprecated_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index 2c169abbf3..07ea39a8a3 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -16,11 +16,7 @@ import logging
 from http import HTTPStatus
 
 from synapse.api.errors import Codes, SynapseError
-from synapse.http.servlet import (
-    RestServlet,
-    assert_params_in_dict,
-    parse_json_object_from_request,
-)
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
 
 from ._base import client_patterns
 
@@ -42,15 +38,14 @@ class ReportEventRestServlet(RestServlet):
         user_id = requester.user.to_string()
 
         body = parse_json_object_from_request(request)
-        assert_params_in_dict(body, ("reason", "score"))
 
-        if not isinstance(body["reason"], str):
+        if not isinstance(body.get("reason", ""), str):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "Param 'reason' must be a string",
                 Codes.BAD_JSON,
             )
-        if not isinstance(body["score"], int):
+        if not isinstance(body.get("score", 0), int):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "Param 'score' must be an integer",
@@ -61,7 +56,7 @@ class ReportEventRestServlet(RestServlet):
             room_id=room_id,
             event_id=event_id,
             user_id=user_id,
-            reason=body["reason"],
+            reason=body.get("reason"),
             content=body,
             received_ts=self.clock.time_msec(),
         )
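
The net effect of dropping assert_params_in_dict is that 'reason' and 'score' become optional; only present-but-mistyped values are rejected. A sketch of request bodies the endpoint would now accept or reject (example values only):

    # POST /_matrix/client/r0/rooms/{roomId}/report/{eventId}
    accepted_bodies = [
        {"reason": "spam", "score": -100},  # both fields, as before
        {"reason": "spam"},                 # score omitted
        {},                                 # both omitted
    ]
    rejected_bodies = [
        {"reason": 42},      # 400: "Param 'reason' must be a string"
        {"score": "high"},   # 400: "Param 'score' must be an integer"
    ]
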
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index aba1734a55..d56a1ae482 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -22,6 +22,7 @@ from synapse.crypto.keyring import ServerKeyFetcher
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_integer, parse_json_object_from_request
 from synapse.util import json_decoder
+from synapse.util.async_helpers import yieldable_gather_results
 
 logger = logging.getLogger(__name__)
 
@@ -210,7 +211,13 @@ class RemoteKey(DirectServeJsonResource):
         # If there is a cache miss, request the missing keys, then recurse (and
         # ensure the result is sent).
         if cache_misses and query_remote_on_cache_miss:
-            await self.fetcher.get_keys(cache_misses)
+            await yieldable_gather_results(
+                lambda t: self.fetcher.get_keys(*t),
+                (
+                    (server_name, list(keys), 0)
+                    for server_name, keys in cache_misses.items()
+                ),
+            )
             await self.query_keys(request, query, query_remote_on_cache_miss=False)
         else:
             signed_keys = []
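
The rewritten KeyRing fetches keys per origin server, so the resource fans out one get_keys call per server and gathers the results; a minimal sketch of that pattern (helper name and values are illustrative):

    from synapse.util.async_helpers import yieldable_gather_results

    async def refetch_missing_keys(fetcher, cache_misses):
        # cache_misses maps server_name -> iterable of missing key IDs, as above.
        # yieldable_gather_results invokes the callable once per item, runs the
        # resulting coroutines concurrently, and waits for all of them.
        await yieldable_gather_results(
            lambda t: fetcher.get_keys(*t),
            (
                (server_name, list(key_ids), 0)  # 0 is the minimum validity ts, as above
                for server_name, key_ids in cache_misses.items()
            ),
        )
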
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index f7872501a0..c57ae5ef15 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -168,6 +168,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         backfilled,
     ):
         self._invalidate_get_event_cache(event_id)
+        self.have_seen_event.invalidate((room_id, event_id))
 
         self.get_latest_event_ids_in_room.invalidate((room_id,))
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 6963bbf7f4..403a5ddaba 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -22,6 +22,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Set,
     Tuple,
     overload,
 )
@@ -55,7 +56,7 @@ from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
@@ -1045,32 +1046,74 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {r["event_id"] for r in rows}
 
-    async def have_seen_events(self, event_ids):
+    async def have_seen_events(
+        self, room_id: str, event_ids: Iterable[str]
+    ) -> Set[str]:
         """Given a list of event ids, check if we have already processed them.
 
+        The room_id is only used to structure the cache (so that it can later be
+        invalidated by room_id) - there is no guarantee that the events are actually
+        in the room in question.
+
         Args:
-            event_ids (iterable[str]):
+            room_id: Room we are polling
+            event_ids: events we are looking for
 
         Returns:
             set[str]: The events we have already seen.
         """
+        res = await self._have_seen_events_dict(
+            (room_id, event_id) for event_id in event_ids
+        )
+        return {eid for ((_rid, eid), have_event) in res.items() if have_event}
+
+    @cachedList("have_seen_event", "keys")
+    async def _have_seen_events_dict(
+        self, keys: Iterable[Tuple[str, str]]
+    ) -> Dict[Tuple[str, str], bool]:
+        """Helper for have_seen_events
+
+        Returns:
+            a dict {(room_id, event_id) -> bool}
+        """
         # if the event cache contains the event, obviously we've seen it.
-        results = {x for x in event_ids if self._get_event_cache.contains(x)}
 
-        def have_seen_events_txn(txn, chunk):
-            sql = "SELECT event_id FROM events as e WHERE "
+        cache_results = {
+            (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,))
+        }
+        results = {x: True for x in cache_results}
+
+        def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
+            # we deliberately do *not* query the database for room_id, to make the
+            # query an index-only lookup on `events_event_id_key`.
+            #
+            # We therefore pull the events from the database into a set...
+
+            sql = "SELECT event_id FROM events AS e WHERE "
             clause, args = make_in_list_sql_clause(
-                txn.database_engine, "e.event_id", chunk
+                txn.database_engine, "e.event_id", [eid for (_rid, eid) in chunk]
             )
             txn.execute(sql + clause, args)
-            results.update(row[0] for row in txn)
+            found_events = {eid for eid, in txn}
 
-        for chunk in batch_iter((x for x in event_ids if x not in results), 100):
+            # ... and then we can update the results for each row in the batch
+            results.update({(rid, eid): (eid in found_events) for (rid, eid) in chunk})
+
+        # each batch requires its own index scan, so we make the batches as big as
+        # possible.
+        for chunk in batch_iter((k for k in keys if k not in cache_results), 500):
             await self.db_pool.runInteraction(
                 "have_seen_events", have_seen_events_txn, chunk
             )
+
         return results
 
+    @cached(max_entries=100000, tree=True)
+    async def have_seen_event(self, room_id: str, event_id: str):
+        # this only exists for the benefit of the @cachedList descriptor on
+        # _have_seen_events_dict
+        raise NotImplementedError()
+
     def _get_current_state_event_counts_txn(self, txn, room_id):
         """
         See get_current_state_event_counts.
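
The have_seen_event stub above exists only so that @cachedList has a per-key cache to read from and write back to. A generic sketch of that pairing, mirroring the structure above with illustrative names:

    from typing import Dict, Iterable, Tuple

    from synapse.util.caches.descriptors import cached, cachedList

    class ExampleStore:
        @cached(max_entries=100000, tree=True)
        async def seen_thing(self, room_id: str, thing_id: str) -> bool:
            # Never called directly: it exists to own the cache, keyed on
            # (room_id, thing_id), so entries can be invalidated per key or
            # by the (room_id,) prefix thanks to tree=True.
            raise NotImplementedError()

        @cachedList("seen_thing", "keys")
        async def seen_things(
            self, keys: Iterable[Tuple[str, str]]
        ) -> Dict[Tuple[str, str], bool]:
            # Only keys that miss the per-key cache arrive here; the returned
            # dict is written back into seen_thing's cache entry by entry.
            return {key: True for key in keys}
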
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 8f83748b5e..7fb7780d0f 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -16,14 +16,14 @@ import logging
 from typing import Any, List, Set, Tuple
 
 from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore
+from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.state import StateGroupWorkerStore
 from synapse.types import RoomStreamToken
 
 logger = logging.getLogger(__name__)
 
 
-class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
+class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
     async def purge_history(
         self, room_id: str, token: str, delete_local_events: bool
     ) -> Set[int]:
@@ -203,8 +203,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
             "DELETE FROM event_to_state_groups "
             "WHERE event_id IN (SELECT event_id from events_to_purge)"
         )
-        for event_id, _ in event_rows:
-            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
 
         # Delete all remote non-state events
         for table in (
@@ -283,6 +281,20 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         # so make sure to keep this actually last.
         txn.execute("DROP TABLE events_to_purge")
 
+        for event_id, should_delete in event_rows:
+            self._invalidate_cache_and_stream(
+                txn, self._get_state_group_for_event, (event_id,)
+            )
+
+            # XXX: This is racy, since have_seen_events could be called between the
+            #    transaction completing and the invalidation running. On the other hand,
+            #    that's no different to calling `have_seen_events` just before the
+            #    event is deleted from the database.
+            if should_delete:
+                self._invalidate_cache_and_stream(
+                    txn, self.have_seen_event, (room_id, event_id)
+                )
+
         logger.info("[purge] done")
 
         return referenced_state_groups
@@ -422,7 +434,11 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         #       index on them. In any case we should be clearing out 'stream' tables
         #       periodically anyway (#5888)
 
-        # TODO: we could probably usefully do a bunch of cache invalidation here
+        # TODO: we could probably usefully do a bunch more cache invalidation here
+
+        # XXX: as with purge_history, this is racy, but no worse than other races
+        #   that already exist.
+        self._invalidate_cache_and_stream(txn, self.have_seen_event, (room_id,))
 
         logger.info("[purge] done")
 
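
Note that the room-level invalidation above relies on have_seen_event being a tree cache: the key is (room_id, event_id) with tree=True, so invalidating on the (room_id,) prefix drops every entry for the room in one call. A sketch of the two granularities used in this file, wrapped in a hypothetical helper for illustration:

    def _invalidate_seen_caches_txn(self, txn, room_id: str, event_id: str) -> None:
        # per-event: matches the hook added in cache.py
        self._invalidate_cache_and_stream(txn, self.have_seen_event, (room_id, event_id))
        # room-wide: a key *prefix* suffices because the cache is declared
        # with tree=True, so every (room_id, *) entry is dropped
        self._invalidate_cache_and_stream(txn, self.have_seen_event, (room_id,))
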
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 5f38634f48..0cf450f81d 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1498,7 +1498,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         room_id: str,
         event_id: str,
         user_id: str,
-        reason: str,
+        reason: Optional[str],
         content: JsonDict,
         received_ts: int,
     ) -> None: