summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <patrickc@matrix.org>2021-06-02 11:38:54 -0400
committerPatrick Cloke <patrickc@matrix.org>2021-06-02 11:38:54 -0400
commit09361655d2fcdac24642efa2b60122dd4f0684be (patch)
treea8a245f6185f6102b47b6b35fbd33a47d8447571
parentMerge remote-tracking branch 'origin/release-v1.35' into matrix-org-hotfixes (diff)
parentRewrite the KeyRing (#10035) (diff)
downloadsynapse-09361655d2fcdac24642efa2b60122dd4f0684be.tar.xz
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
-rw-r--r--CHANGES.md11
-rw-r--r--changelog.d/10035.feature1
-rw-r--r--changelog.d/10048.misc1
-rw-r--r--changelog.d/10074.misc1
-rw-r--r--changelog.d/10077.feature1
-rw-r--r--changelog.d/10084.feature1
-rw-r--r--changelog.d/10091.misc1
-rw-r--r--changelog.d/10092.bugfix1
-rw-r--r--changelog.d/10102.misc1
-rw-r--r--changelog.d/10109.bugfix1
-rw-r--r--changelog.d/9953.feature1
-rw-r--r--changelog.d/9973.feature1
-rw-r--r--changelog.d/9973.misc1
-rw-r--r--debian/changelog6
-rw-r--r--docs/admin_api/event_reports.md4
-rw-r--r--docs/workers.md3
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py8
-rw-r--r--synapse/app/generic_worker.py4
-rw-r--r--synapse/crypto/keyring.py642
-rw-r--r--synapse/federation/transport/server.py7
-rw-r--r--synapse/groups/attestations.py4
-rw-r--r--synapse/handlers/federation.py12
-rw-r--r--synapse/handlers/space_summary.py19
-rw-r--r--synapse/handlers/sync.py4
-rw-r--r--synapse/http/servlet.py196
-rw-r--r--synapse/logging/opentracing.py31
-rw-r--r--synapse/metrics/background_process_metrics.py10
-rw-r--r--synapse/rest/client/v1/room.py8
-rw-r--r--synapse/rest/client/v2_alpha/report_event.py13
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py9
-rw-r--r--synapse/storage/databases/main/cache.py1
-rw-r--r--synapse/storage/databases/main/events_worker.py61
-rw-r--r--synapse/storage/databases/main/purge_events.py26
-rw-r--r--synapse/storage/databases/main/room.py2
-rw-r--r--tests/crypto/test_keyring.py170
-rw-r--r--tests/rest/admin/test_event_reports.py15
-rw-r--r--tests/rest/client/v2_alpha/test_report_event.py83
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py18
-rw-r--r--tests/storage/databases/__init__.py13
-rw-r--r--tests/storage/databases/main/__init__.py13
-rw-r--r--tests/storage/databases/main/test_events_worker.py96
-rw-r--r--tests/util/test_batching_queue.py37
43 files changed, 935 insertions, 605 deletions
diff --git a/CHANGES.md b/CHANGES.md
index 7e6f478d42..f03a53affc 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,5 +1,12 @@
-Synapse 1.35.0rc3 (2021-05-28)
-==============================
+Synapse 1.35.0 (2021-06-01)
+===========================
+
+Note that [the tag](https://github.com/matrix-org/synapse/releases/tag/v1.35.0rc3) and [docker images](https://hub.docker.com/layers/matrixdotorg/synapse/v1.35.0rc3/images/sha256-34ccc87bd99a17e2cbc0902e678b5937d16bdc1991ead097eee6096481ecf2c4?context=explore) for `v1.35.0rc3` were incorrectly built. If you are experiencing issues with either, it is recommended to upgrade to the equivalent tag or docker image for the `v1.35.0` release.
+
+Deprecations and Removals
+-------------------------
+
+- The core Synapse development team plan to drop support for the [unstable API of MSC2858](https://github.com/matrix-org/matrix-doc/blob/master/proposals/2858-Multiple-SSO-Identity-Providers.md#unstable-prefix), including the undocumented `experimental.msc2858_enabled` config option, in August 2021. Client authors should ensure that their clients are updated to use the stable API (which has been supported since Synapse 1.30) well before that time, to give their users time to upgrade. ([\#10101](https://github.com/matrix-org/synapse/issues/10101))
 
 Bugfixes
 --------
diff --git a/changelog.d/10035.feature b/changelog.d/10035.feature
new file mode 100644
index 0000000000..68052b5a7e
--- /dev/null
+++ b/changelog.d/10035.feature
@@ -0,0 +1 @@
+Rewrite logic around verifying JSON object and fetching server keys to be more performant and use less memory.
diff --git a/changelog.d/10048.misc b/changelog.d/10048.misc
new file mode 100644
index 0000000000..a901f8431e
--- /dev/null
+++ b/changelog.d/10048.misc
@@ -0,0 +1 @@
+Add `parse_strings_from_args` for parsing an array from query parameters.
diff --git a/changelog.d/10074.misc b/changelog.d/10074.misc
new file mode 100644
index 0000000000..8dbe2cd2bc
--- /dev/null
+++ b/changelog.d/10074.misc
@@ -0,0 +1 @@
+Update opentracing to inject the right context into the carrier.
diff --git a/changelog.d/10077.feature b/changelog.d/10077.feature
new file mode 100644
index 0000000000..808feb2215
--- /dev/null
+++ b/changelog.d/10077.feature
@@ -0,0 +1 @@
+Make reason and score parameters optional for reporting content. Implements [MSC2414](https://github.com/matrix-org/matrix-doc/pull/2414). Contributed by Callum Brown.
diff --git a/changelog.d/10084.feature b/changelog.d/10084.feature
new file mode 100644
index 0000000000..602cb6ff51
--- /dev/null
+++ b/changelog.d/10084.feature
@@ -0,0 +1 @@
+Add support for routing more requests to workers.
diff --git a/changelog.d/10091.misc b/changelog.d/10091.misc
new file mode 100644
index 0000000000..dbe310fd17
--- /dev/null
+++ b/changelog.d/10091.misc
@@ -0,0 +1 @@
+Log method and path when dropping request due to size limit.
diff --git a/changelog.d/10092.bugfix b/changelog.d/10092.bugfix
new file mode 100644
index 0000000000..09b2aba7ff
--- /dev/null
+++ b/changelog.d/10092.bugfix
@@ -0,0 +1 @@
+Fix a bug in the `force_tracing_for_users` option introduced in Synapse v1.35 which meant that the OpenTracing spans produced were missing most tags.
diff --git a/changelog.d/10102.misc b/changelog.d/10102.misc
new file mode 100644
index 0000000000..87672ee295
--- /dev/null
+++ b/changelog.d/10102.misc
@@ -0,0 +1 @@
+Make `/sync` do fewer state resolutions.
diff --git a/changelog.d/10109.bugfix b/changelog.d/10109.bugfix
new file mode 100644
index 0000000000..bc41bf9e5e
--- /dev/null
+++ b/changelog.d/10109.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in v1.35.0 where invite-only rooms would be shown to users in a space who were not invited.
diff --git a/changelog.d/9953.feature b/changelog.d/9953.feature
new file mode 100644
index 0000000000..6b3d1adc70
--- /dev/null
+++ b/changelog.d/9953.feature
@@ -0,0 +1 @@
+Improve performance of incoming federation transactions in large rooms.
diff --git a/changelog.d/9973.feature b/changelog.d/9973.feature
new file mode 100644
index 0000000000..6b3d1adc70
--- /dev/null
+++ b/changelog.d/9973.feature
@@ -0,0 +1 @@
+Improve performance of incoming federation transactions in large rooms.
diff --git a/changelog.d/9973.misc b/changelog.d/9973.misc
deleted file mode 100644
index 7f22d42291..0000000000
--- a/changelog.d/9973.misc
+++ /dev/null
@@ -1 +0,0 @@
-Make `LruCache.invalidate` support tree invalidation, and remove `invalidate_many`.
diff --git a/debian/changelog b/debian/changelog
index bf99ae772c..d5efb8ccba 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+matrix-synapse-py3 (1.35.0) stable; urgency=medium
+
+  * New synapse release 1.35.0.
+
+ -- Synapse Packaging team <packages@matrix.org>  Tue, 01 Jun 2021 13:23:35 +0100
+
 matrix-synapse-py3 (1.34.0) stable; urgency=medium
 
   * New synapse release 1.34.0.
diff --git a/docs/admin_api/event_reports.md b/docs/admin_api/event_reports.md
index 0159098138..bfec06f755 100644
--- a/docs/admin_api/event_reports.md
+++ b/docs/admin_api/event_reports.md
@@ -75,9 +75,9 @@ The following fields are returned in the JSON response body:
 * `name`: string - The name of the room.
 * `event_id`: string - The ID of the reported event.
 * `user_id`: string - This is the user who reported the event and wrote the reason.
-* `reason`: string - Comment made by the `user_id` in this report. May be blank.
+* `reason`: string - Comment made by the `user_id` in this report. May be blank or `null`.
 * `score`: integer - Content is reported based upon a negative score, where -100 is
-  "most offensive" and 0 is "inoffensive".
+  "most offensive" and 0 is "inoffensive". May be `null`.
 * `sender`: string - This is the ID of the user who sent the original message/event that
   was reported.
 * `canonical_alias`: string - The canonical alias of the room. `null` if the room does not
diff --git a/docs/workers.md b/docs/workers.md
index c6282165b0..46b5e4b737 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -228,6 +228,9 @@ expressions:
     ^/_matrix/client/(api/v1|r0|unstable)/joined_groups$
     ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$
     ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/
+    ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/event/
+    ^/_matrix/client/(api/v1|r0|unstable)/joined_rooms$
+    ^/_matrix/client/(api/v1|r0|unstable)/search$
 
     # Registration/login requests
     ^/_matrix/client/(api/v1|r0|unstable)/login$
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 4591246bd1..d9843a1708 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.35.0rc3"
+__version__ = "1.35.0"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 458306eba5..26a3b38918 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -206,11 +206,11 @@ class Auth:
                 requester = create_requester(user_id, app_service=app_service)
 
                 request.requester = user_id
+                if user_id in self._force_tracing_for_users:
+                    opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
                 opentracing.set_tag("authenticated_entity", user_id)
                 opentracing.set_tag("user_id", user_id)
                 opentracing.set_tag("appservice_id", app_service.id)
-                if user_id in self._force_tracing_for_users:
-                    opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
 
                 return requester
 
@@ -259,12 +259,12 @@ class Auth:
             )
 
             request.requester = requester
+            if user_info.token_owner in self._force_tracing_for_users:
+                opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
             opentracing.set_tag("authenticated_entity", user_info.token_owner)
             opentracing.set_tag("user_id", user_info.user_id)
             if device_id:
                 opentracing.set_tag("device_id", device_id)
-            if user_info.token_owner in self._force_tracing_for_users:
-                opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
 
             return requester
         except KeyError:
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 91ad326f19..57c2fc2e88 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -109,7 +109,7 @@ from synapse.storage.databases.main.monthly_active_users import (
     MonthlyActiveUsersWorkerStore,
 )
 from synapse.storage.databases.main.presence import PresenceStore
-from synapse.storage.databases.main.search import SearchWorkerStore
+from synapse.storage.databases.main.search import SearchStore
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.databases.main.transactions import TransactionWorkerStore
 from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
@@ -242,7 +242,7 @@ class GenericWorkerSlavedStore(
     MonthlyActiveUsersWorkerStore,
     MediaRepositoryStore,
     ServerMetricsStore,
-    SearchWorkerStore,
+    SearchStore,
     TransactionWorkerStore,
     BaseSlavedStore,
 ):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 6fc0712978..c840ffca71 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -16,8 +16,7 @@
 import abc
 import logging
 import urllib
-from collections import defaultdict
-from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
 
 import attr
 from signedjson.key import (
@@ -44,17 +43,12 @@ from synapse.api.errors import (
 from synapse.config.key import TrustedKeyServer
 from synapse.events import EventBase
 from synapse.events.utils import prune_event_dict
-from synapse.logging.context import (
-    PreserveLoggingContext,
-    make_deferred_yieldable,
-    preserve_fn,
-    run_in_background,
-)
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage.keys import FetchKeyResult
 from synapse.types import JsonDict
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import yieldable_gather_results
-from synapse.util.metrics import Measure
+from synapse.util.batching_queue import BatchingQueue
 from synapse.util.retryutils import NotRetryingDestination
 
 if TYPE_CHECKING:
@@ -80,32 +74,19 @@ class VerifyJsonRequest:
         minimum_valid_until_ts: time at which we require the signing key to
             be valid. (0 implies we don't care)
 
-        request_name: The name of the request.
-
         key_ids: The set of key_ids to that could be used to verify the JSON object
-
-        key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
-            A deferred (server_name, key_id, verify_key) tuple that resolves when
-            a verify key has been fetched. The deferreds' callbacks are run with no
-            logcontext.
-
-            If we are unable to find a key which satisfies the request, the deferred
-            errbacks with an M_UNAUTHORIZED SynapseError.
     """
 
     server_name = attr.ib(type=str)
     get_json_object = attr.ib(type=Callable[[], JsonDict])
     minimum_valid_until_ts = attr.ib(type=int)
-    request_name = attr.ib(type=str)
     key_ids = attr.ib(type=List[str])
-    key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
 
     @staticmethod
     def from_json_object(
         server_name: str,
         json_object: JsonDict,
         minimum_valid_until_ms: int,
-        request_name: str,
     ):
         """Create a VerifyJsonRequest to verify all signatures on a signed JSON
         object for the given server.
@@ -115,7 +96,6 @@ class VerifyJsonRequest:
             server_name,
             lambda: json_object,
             minimum_valid_until_ms,
-            request_name=request_name,
             key_ids=key_ids,
         )
 
@@ -135,16 +115,48 @@ class VerifyJsonRequest:
             # memory than the Event object itself.
             lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
             minimum_valid_until_ms,
-            request_name=event.event_id,
             key_ids=key_ids,
         )
 
+    def to_fetch_key_request(self) -> "_FetchKeyRequest":
+        """Create a key fetch request for all keys needed to satisfy the
+        verification request.
+        """
+        return _FetchKeyRequest(
+            server_name=self.server_name,
+            minimum_valid_until_ts=self.minimum_valid_until_ts,
+            key_ids=self.key_ids,
+        )
+
 
 class KeyLookupError(ValueError):
     pass
 
 
+@attr.s(slots=True)
+class _FetchKeyRequest:
+    """A request for keys for a given server.
+
+    We will continue to try and fetch until we have all the keys listed under
+    `key_ids` (with an appropriate `valid_until_ts` property) or we run out of
+    places to fetch keys from.
+
+    Attributes:
+        server_name: The name of the server that owns the keys.
+        minimum_valid_until_ts: The timestamp which the keys must be valid until.
+        key_ids: The IDs of the keys to attempt to fetch
+    """
+
+    server_name = attr.ib(type=str)
+    minimum_valid_until_ts = attr.ib(type=int)
+    key_ids = attr.ib(type=List[str])
+
+
 class Keyring:
+    """Handles verifying signed JSON objects and fetching the keys needed to do
+    so.
+    """
+
     def __init__(
         self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
     ):
@@ -158,22 +170,22 @@ class Keyring:
             )
         self._key_fetchers = key_fetchers
 
-        # map from server name to Deferred. Has an entry for each server with
-        # an ongoing key download; the Deferred completes once the download
-        # completes.
-        #
-        # These are regular, logcontext-agnostic Deferreds.
-        self.key_downloads = {}  # type: Dict[str, defer.Deferred]
+        self._server_queue = BatchingQueue(
+            "keyring_server",
+            clock=hs.get_clock(),
+            process_batch_callback=self._inner_fetch_key_requests,
+        )  # type: BatchingQueue[_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]]
 
-    def verify_json_for_server(
+    async def verify_json_for_server(
         self,
         server_name: str,
         json_object: JsonDict,
         validity_time: int,
-        request_name: str,
-    ) -> defer.Deferred:
+    ) -> None:
         """Verify that a JSON object has been signed by a given server
 
+        Completes if the the object was correctly signed, otherwise raises.
+
         Args:
             server_name: name of the server which must have signed this object
 
@@ -181,52 +193,45 @@ class Keyring:
 
             validity_time: timestamp at which we require the signing key to
                 be valid. (0 implies we don't care)
-
-            request_name: an identifier for this json object (eg, an event id)
-                for logging.
-
-        Returns:
-            Deferred[None]: completes if the the object was correctly signed, otherwise
-                errbacks with an error
         """
         request = VerifyJsonRequest.from_json_object(
             server_name,
             json_object,
             validity_time,
-            request_name,
         )
-        requests = (request,)
-        return make_deferred_yieldable(self._verify_objects(requests)[0])
+        return await self.process_request(request)
 
     def verify_json_objects_for_server(
-        self, server_and_json: Iterable[Tuple[str, dict, int, str]]
+        self, server_and_json: Iterable[Tuple[str, dict, int]]
     ) -> List[defer.Deferred]:
         """Bulk verifies signatures of json objects, bulk fetching keys as
         necessary.
 
         Args:
             server_and_json:
-                Iterable of (server_name, json_object, validity_time, request_name)
+                Iterable of (server_name, json_object, validity_time)
                 tuples.
 
                 validity_time is a timestamp at which the signing key must be
                 valid.
 
-                request_name is an identifier for this json object (eg, an event id)
-                for logging.
-
         Returns:
             List<Deferred[None]>: for each input triplet, a deferred indicating success
                 or failure to verify each json object's signature for the given
                 server_name. The deferreds run their callbacks in the sentinel
                 logcontext.
         """
-        return self._verify_objects(
-            VerifyJsonRequest.from_json_object(
-                server_name, json_object, validity_time, request_name
+        return [
+            run_in_background(
+                self.process_request,
+                VerifyJsonRequest.from_json_object(
+                    server_name,
+                    json_object,
+                    validity_time,
+                ),
             )
-            for server_name, json_object, validity_time, request_name in server_and_json
-        )
+            for server_name, json_object, validity_time in server_and_json
+        ]
 
     def verify_events_for_server(
         self, server_and_events: Iterable[Tuple[str, EventBase, int]]
@@ -252,321 +257,223 @@ class Keyring:
                 server_name. The deferreds run their callbacks in the sentinel
                 logcontext.
         """
-        return self._verify_objects(
-            VerifyJsonRequest.from_event(server_name, event, validity_time)
+        return [
+            run_in_background(
+                self.process_request,
+                VerifyJsonRequest.from_event(
+                    server_name,
+                    event,
+                    validity_time,
+                ),
+            )
             for server_name, event, validity_time in server_and_events
-        )
-
-    def _verify_objects(
-        self, verify_requests: Iterable[VerifyJsonRequest]
-    ) -> List[defer.Deferred]:
-        """Does the work of verify_json_[objects_]for_server
-
-
-        Args:
-            verify_requests: Iterable of verification requests.
+        ]
 
-        Returns:
-            List<Deferred[None]>: for each input item, a deferred indicating success
-                or failure to verify each json object's signature for the given
-                server_name. The deferreds run their callbacks in the sentinel
-                logcontext.
+    async def process_request(self, verify_request: VerifyJsonRequest) -> None:
+        """Processes the `VerifyJsonRequest`. Raises if the object is not signed
+        by the server, the signatures don't match or we failed to fetch the
+        necessary keys.
         """
-        # a list of VerifyJsonRequests which are awaiting a key lookup
-        key_lookups = []
-        handle = preserve_fn(_handle_key_deferred)
-
-        def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
-            """Process an entry in the request list
-
-            Adds a key request to key_lookups, and returns a deferred which
-            will complete or fail (in the sentinel context) when verification completes.
-            """
-            if not verify_request.key_ids:
-                return defer.fail(
-                    SynapseError(
-                        400,
-                        "Not signed by %s" % (verify_request.server_name,),
-                        Codes.UNAUTHORIZED,
-                    )
-                )
 
-            logger.debug(
-                "Verifying %s for %s with key_ids %s, min_validity %i",
-                verify_request.request_name,
-                verify_request.server_name,
-                verify_request.key_ids,
-                verify_request.minimum_valid_until_ts,
+        if not verify_request.key_ids:
+            raise SynapseError(
+                400,
+                f"Not signed by {verify_request.server_name}",
+                Codes.UNAUTHORIZED,
             )
 
-            # add the key request to the queue, but don't start it off yet.
-            key_lookups.append(verify_request)
-
-            # now run _handle_key_deferred, which will wait for the key request
-            # to complete and then do the verification.
-            #
-            # We want _handle_key_request to log to the right context, so we
-            # wrap it with preserve_fn (aka run_in_background)
-            return handle(verify_request)
-
-        results = [process(r) for r in verify_requests]
-
-        if key_lookups:
-            run_in_background(self._start_key_lookups, key_lookups)
-
-        return results
-
-    async def _start_key_lookups(
-        self, verify_requests: List[VerifyJsonRequest]
-    ) -> None:
-        """Sets off the key fetches for each verify request
-
-        Once each fetch completes, verify_request.key_ready will be resolved.
-
-        Args:
-            verify_requests:
-        """
-
-        try:
-            # map from server name to a set of outstanding request ids
-            server_to_request_ids = {}  # type: Dict[str, Set[int]]
-
-            for verify_request in verify_requests:
-                server_name = verify_request.server_name
-                request_id = id(verify_request)
-                server_to_request_ids.setdefault(server_name, set()).add(request_id)
-
-            # Wait for any previous lookups to complete before proceeding.
-            await self.wait_for_previous_lookups(server_to_request_ids.keys())
-
-            # take out a lock on each of the servers by sticking a Deferred in
-            # key_downloads
-            for server_name in server_to_request_ids.keys():
-                self.key_downloads[server_name] = defer.Deferred()
-                logger.debug("Got key lookup lock on %s", server_name)
-
-            # When we've finished fetching all the keys for a given server_name,
-            # drop the lock by resolving the deferred in key_downloads.
-            def drop_server_lock(server_name):
-                d = self.key_downloads.pop(server_name)
-                d.callback(None)
-
-            def lookup_done(res, verify_request):
-                server_name = verify_request.server_name
-                server_requests = server_to_request_ids[server_name]
-                server_requests.remove(id(verify_request))
-
-                # if there are no more requests for this server, we can drop the lock.
-                if not server_requests:
-                    logger.debug("Releasing key lookup lock on %s", server_name)
-                    drop_server_lock(server_name)
-
-                return res
+        # Add the keys we need to verify to the queue for retrieval. We queue
+        # up requests for the same server so we don't end up with many in flight
+        # requests for the same keys.
+        key_request = verify_request.to_fetch_key_request()
+        found_keys_by_server = await self._server_queue.add_to_queue(
+            key_request, key=verify_request.server_name
+        )
 
-            for verify_request in verify_requests:
-                verify_request.key_ready.addBoth(lookup_done, verify_request)
+        # Since we batch up requests the returned set of keys may contain keys
+        # from other servers, so we pull out only the ones we care about.s
+        found_keys = found_keys_by_server.get(verify_request.server_name, {})
 
-            # Actually start fetching keys.
-            self._get_server_verify_keys(verify_requests)
-        except Exception:
-            logger.exception("Error starting key lookups")
+        # Verify each signature we got valid keys for, raising if we can't
+        # verify any of them.
+        verified = False
+        for key_id in verify_request.key_ids:
+            key_result = found_keys.get(key_id)
+            if not key_result:
+                continue
 
-    async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
-        """Waits for any previous key lookups for the given servers to finish.
+            if key_result.valid_until_ts < verify_request.minimum_valid_until_ts:
+                continue
 
-        Args:
-            server_names: list of servers which we want to look up
+            verify_key = key_result.verify_key
+            json_object = verify_request.get_json_object()
+            try:
+                verify_signed_json(
+                    json_object,
+                    verify_request.server_name,
+                    verify_key,
+                )
+                verified = True
+            except SignatureVerifyException as e:
+                logger.debug(
+                    "Error verifying signature for %s:%s:%s with key %s: %s",
+                    verify_request.server_name,
+                    verify_key.alg,
+                    verify_key.version,
+                    encode_verify_key_base64(verify_key),
+                    str(e),
+                )
+                raise SynapseError(
+                    401,
+                    "Invalid signature for server %s with key %s:%s: %s"
+                    % (
+                        verify_request.server_name,
+                        verify_key.alg,
+                        verify_key.version,
+                        str(e),
+                    ),
+                    Codes.UNAUTHORIZED,
+                )
 
-        Returns:
-            Resolves once all key lookups for the given servers have
-                completed. Follows the synapse rules of logcontext preservation.
-        """
-        loop_count = 1
-        while True:
-            wait_on = [
-                (server_name, self.key_downloads[server_name])
-                for server_name in server_names
-                if server_name in self.key_downloads
-            ]
-            if not wait_on:
-                break
-            logger.info(
-                "Waiting for existing lookups for %s to complete [loop %i]",
-                [w[0] for w in wait_on],
-                loop_count,
+        if not verified:
+            raise SynapseError(
+                401,
+                f"Failed to find any key to satisfy: {key_request}",
+                Codes.UNAUTHORIZED,
             )
-            with PreserveLoggingContext():
-                await defer.DeferredList((w[1] for w in wait_on))
 
-            loop_count += 1
+    async def _inner_fetch_key_requests(
+        self, requests: List[_FetchKeyRequest]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
+        """Processing function for the queue of `_FetchKeyRequest`."""
+
+        logger.debug("Starting fetch for %s", requests)
+
+        # First we need to deduplicate requests for the same key. We do this by
+        # taking the *maximum* requested `minimum_valid_until_ts` for each pair
+        # of server name/key ID.
+        server_to_key_to_ts = {}  # type: Dict[str, Dict[str, int]]
+        for request in requests:
+            by_server = server_to_key_to_ts.setdefault(request.server_name, {})
+            for key_id in request.key_ids:
+                existing_ts = by_server.get(key_id, 0)
+                by_server[key_id] = max(request.minimum_valid_until_ts, existing_ts)
+
+        deduped_requests = [
+            _FetchKeyRequest(server_name, minimum_valid_ts, [key_id])
+            for server_name, by_server in server_to_key_to_ts.items()
+            for key_id, minimum_valid_ts in by_server.items()
+        ]
+
+        logger.debug("Deduplicated key requests to %s", deduped_requests)
+
+        # For each key we call `_inner_verify_request` which will handle
+        # fetching each key. Note these shouldn't throw if we fail to contact
+        # other servers etc.
+        results_per_request = await yieldable_gather_results(
+            self._inner_fetch_key_request,
+            deduped_requests,
+        )
 
-    def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
-        """Tries to find at least one key for each verify request
+        # We now convert the returned list of results into a map from server
+        # name to key ID to FetchKeyResult, to return.
+        to_return = {}  # type: Dict[str, Dict[str, FetchKeyResult]]
+        for (request, results) in zip(deduped_requests, results_per_request):
+            to_return_by_server = to_return.setdefault(request.server_name, {})
+            for key_id, key_result in results.items():
+                existing = to_return_by_server.get(key_id)
+                if not existing or existing.valid_until_ts < key_result.valid_until_ts:
+                    to_return_by_server[key_id] = key_result
 
-        For each verify_request, verify_request.key_ready is called back with
-        params (server_name, key_id, VerifyKey) if a key is found, or errbacked
-        with a SynapseError if none of the keys are found.
+        return to_return
 
-        Args:
-            verify_requests: list of verify requests
+    async def _inner_fetch_key_request(
+        self, verify_request: _FetchKeyRequest
+    ) -> Dict[str, FetchKeyResult]:
+        """Attempt to fetch the given key by calling each key fetcher one by
+        one.
         """
+        logger.debug("Starting fetch for %s", verify_request)
 
-        remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
+        found_keys: Dict[str, FetchKeyResult] = {}
+        missing_key_ids = set(verify_request.key_ids)
 
-        async def do_iterations():
-            try:
-                with Measure(self.clock, "get_server_verify_keys"):
-                    for f in self._key_fetchers:
-                        if not remaining_requests:
-                            return
-                        await self._attempt_key_fetches_with_fetcher(
-                            f, remaining_requests
-                        )
-
-                    # look for any requests which weren't satisfied
-                    while remaining_requests:
-                        verify_request = remaining_requests.pop()
-                        rq_str = (
-                            "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
-                            % (
-                                verify_request.server_name,
-                                verify_request.key_ids,
-                                verify_request.minimum_valid_until_ts,
-                            )
-                        )
-
-                        # If we run the errback immediately, it may cancel our
-                        # loggingcontext while we are still in it, so instead we
-                        # schedule it for the next time round the reactor.
-                        #
-                        # (this also ensures that we don't get a stack overflow if we
-                        # has a massive queue of lookups waiting for this server).
-                        self.clock.call_later(
-                            0,
-                            verify_request.key_ready.errback,
-                            SynapseError(
-                                401,
-                                "Failed to find any key to satisfy %s" % (rq_str,),
-                                Codes.UNAUTHORIZED,
-                            ),
-                        )
-            except Exception as err:
-                # we don't really expect to get here, because any errors should already
-                # have been caught and logged. But if we do, let's log the error and make
-                # sure that all of the deferreds are resolved.
-                logger.error("Unexpected error in _get_server_verify_keys: %s", err)
-                with PreserveLoggingContext():
-                    for verify_request in remaining_requests:
-                        if not verify_request.key_ready.called:
-                            verify_request.key_ready.errback(err)
-
-        run_in_background(do_iterations)
-
-    async def _attempt_key_fetches_with_fetcher(
-        self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
-    ):
-        """Use a key fetcher to attempt to satisfy some key requests
+        for fetcher in self._key_fetchers:
+            if not missing_key_ids:
+                break
 
-        Args:
-            fetcher: fetcher to use to fetch the keys
-            remaining_requests: outstanding key requests.
-                Any successfully-completed requests will be removed from the list.
-        """
-        # The keys to fetch.
-        # server_name -> key_id -> min_valid_ts
-        missing_keys = defaultdict(dict)  # type: Dict[str, Dict[str, int]]
-
-        for verify_request in remaining_requests:
-            # any completed requests should already have been removed
-            assert not verify_request.key_ready.called
-            keys_for_server = missing_keys[verify_request.server_name]
-
-            for key_id in verify_request.key_ids:
-                # If we have several requests for the same key, then we only need to
-                # request that key once, but we should do so with the greatest
-                # min_valid_until_ts of the requests, so that we can satisfy all of
-                # the requests.
-                keys_for_server[key_id] = max(
-                    keys_for_server.get(key_id, -1),
-                    verify_request.minimum_valid_until_ts,
-                )
+            logger.debug("Getting keys from %s for %s", fetcher, verify_request)
+            keys = await fetcher.get_keys(
+                verify_request.server_name,
+                list(missing_key_ids),
+                verify_request.minimum_valid_until_ts,
+            )
 
-        results = await fetcher.get_keys(missing_keys)
+            for key_id, key in keys.items():
+                if not key:
+                    continue
 
-        completed = []
-        for verify_request in remaining_requests:
-            server_name = verify_request.server_name
+                # If we already have a result for the given key ID we keep the
+                # one with the highest `valid_until_ts`.
+                existing_key = found_keys.get(key_id)
+                if existing_key:
+                    if key.valid_until_ts <= existing_key.valid_until_ts:
+                        continue
 
-            # see if any of the keys we got this time are sufficient to
-            # complete this VerifyJsonRequest.
-            result_keys = results.get(server_name, {})
-            for key_id in verify_request.key_ids:
-                fetch_key_result = result_keys.get(key_id)
-                if not fetch_key_result:
-                    # we didn't get a result for this key
-                    continue
+                # We always store the returned key even if it doesn't the
+                # `minimum_valid_until_ts` requirement, as some verification
+                # requests may still be able to be satisfied by it.
+                #
+                # We still keep looking for the key from other fetchers in that
+                # case though.
+                found_keys[key_id] = key
 
-                if (
-                    fetch_key_result.valid_until_ts
-                    < verify_request.minimum_valid_until_ts
-                ):
-                    # key was not valid at this point
+                if key.valid_until_ts < verify_request.minimum_valid_until_ts:
                     continue
 
-                # we have a valid key for this request. If we run the callback
-                # immediately, it may cancel our loggingcontext while we are still in
-                # it, so instead we schedule it for the next time round the reactor.
-                #
-                # (this also ensures that we don't get a stack overflow if we had
-                # a massive queue of lookups waiting for this server).
-                logger.debug(
-                    "Found key %s:%s for %s",
-                    server_name,
-                    key_id,
-                    verify_request.request_name,
-                )
-                self.clock.call_later(
-                    0,
-                    verify_request.key_ready.callback,
-                    (server_name, key_id, fetch_key_result.verify_key),
-                )
-                completed.append(verify_request)
-                break
+                missing_key_ids.discard(key_id)
 
-        remaining_requests.difference_update(completed)
+        return found_keys
 
 
 class KeyFetcher(metaclass=abc.ABCMeta):
-    @abc.abstractmethod
+    def __init__(self, hs: "HomeServer"):
+        self._queue = BatchingQueue(
+            self.__class__.__name__, hs.get_clock(), self._fetch_keys
+        )
+
     async def get_keys(
-        self, keys_to_fetch: Dict[str, Dict[str, int]]
-    ) -> Dict[str, Dict[str, FetchKeyResult]]:
-        """
-        Args:
-            keys_to_fetch:
-                the keys to be fetched. server_name -> key_id -> min_valid_ts
+        self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+    ) -> Dict[str, FetchKeyResult]:
+        results = await self._queue.add_to_queue(
+            _FetchKeyRequest(
+                server_name=server_name,
+                key_ids=key_ids,
+                minimum_valid_until_ts=minimum_valid_until_ts,
+            )
+        )
+        return results.get(server_name, {})
 
-        Returns:
-            Map from server_name -> key_id -> FetchKeyResult
-        """
-        raise NotImplementedError
+    @abc.abstractmethod
+    async def _fetch_keys(
+        self, keys_to_fetch: List[_FetchKeyRequest]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
+        pass
 
 
 class StoreKeyFetcher(KeyFetcher):
     """KeyFetcher impl which fetches keys from our data store"""
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        super().__init__(hs)
 
-    async def get_keys(
-        self, keys_to_fetch: Dict[str, Dict[str, int]]
-    ) -> Dict[str, Dict[str, FetchKeyResult]]:
-        """see KeyFetcher.get_keys"""
+        self.store = hs.get_datastore()
 
+    async def _fetch_keys(self, keys_to_fetch: List[_FetchKeyRequest]):
         key_ids_to_fetch = (
-            (server_name, key_id)
-            for server_name, keys_for_server in keys_to_fetch.items()
-            for key_id in keys_for_server.keys()
+            (queue_value.server_name, key_id)
+            for queue_value in keys_to_fetch
+            for key_id in queue_value.key_ids
         )
 
         res = await self.store.get_server_verify_keys(key_ids_to_fetch)
@@ -578,6 +485,8 @@ class StoreKeyFetcher(KeyFetcher):
 
 class BaseV2KeyFetcher(KeyFetcher):
     def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
         self.store = hs.get_datastore()
         self.config = hs.config
 
@@ -685,10 +594,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         self.client = hs.get_federation_http_client()
         self.key_servers = self.config.key_servers
 
-    async def get_keys(
-        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    async def _fetch_keys(
+        self, keys_to_fetch: List[_FetchKeyRequest]
     ) -> Dict[str, Dict[str, FetchKeyResult]]:
-        """see KeyFetcher.get_keys"""
+        """see KeyFetcher._fetch_keys"""
 
         async def get_key(key_server: TrustedKeyServer) -> Dict:
             try:
@@ -724,12 +633,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         return union_of_keys
 
     async def get_server_verify_key_v2_indirect(
-        self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
+        self, keys_to_fetch: List[_FetchKeyRequest], key_server: TrustedKeyServer
     ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """
         Args:
             keys_to_fetch:
-                the keys to be fetched. server_name -> key_id -> min_valid_ts
+                the keys to be fetched.
 
             key_server: notary server to query for the keys
 
@@ -743,7 +652,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         perspective_name = key_server.server_name
         logger.info(
             "Requesting keys %s from notary server %s",
-            keys_to_fetch.items(),
+            keys_to_fetch,
             perspective_name,
         )
 
@@ -753,11 +662,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
                 path="/_matrix/key/v2/query",
                 data={
                     "server_keys": {
-                        server_name: {
-                            key_id: {"minimum_valid_until_ts": min_valid_ts}
-                            for key_id, min_valid_ts in server_keys.items()
+                        queue_value.server_name: {
+                            key_id: {
+                                "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
+                            }
+                            for key_id in queue_value.key_ids
                         }
-                        for server_name, server_keys in keys_to_fetch.items()
+                        for queue_value in keys_to_fetch
                     }
                 },
             )
@@ -858,7 +769,20 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         self.client = hs.get_federation_http_client()
 
     async def get_keys(
-        self, keys_to_fetch: Dict[str, Dict[str, int]]
+        self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+    ) -> Dict[str, FetchKeyResult]:
+        results = await self._queue.add_to_queue(
+            _FetchKeyRequest(
+                server_name=server_name,
+                key_ids=key_ids,
+                minimum_valid_until_ts=minimum_valid_until_ts,
+            ),
+            key=server_name,
+        )
+        return results.get(server_name, {})
+
+    async def _fetch_keys(
+        self, keys_to_fetch: List[_FetchKeyRequest]
     ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """
         Args:
@@ -871,8 +795,10 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
 
         results = {}
 
-        async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
-            server_name, key_ids = key_to_fetch_item
+        async def get_key(key_to_fetch_item: _FetchKeyRequest) -> None:
+            server_name = key_to_fetch_item.server_name
+            key_ids = key_to_fetch_item.key_ids
+
             try:
                 keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
                 results[server_name] = keys
@@ -883,7 +809,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             except Exception:
                 logger.exception("Error getting keys %s from %s", key_ids, server_name)
 
-        await yieldable_gather_results(get_key, keys_to_fetch.items())
+        await yieldable_gather_results(get_key, keys_to_fetch)
         return results
 
     async def get_server_verify_key_v2_direct(
@@ -955,37 +881,3 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             keys.update(response_keys)
 
         return keys
-
-
-async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
-    """Waits for the key to become available, and then performs a verification
-
-    Args:
-        verify_request:
-
-    Raises:
-        SynapseError if there was a problem performing the verification
-    """
-    server_name = verify_request.server_name
-    with PreserveLoggingContext():
-        _, key_id, verify_key = await verify_request.key_ready
-
-    json_object = verify_request.get_json_object()
-
-    try:
-        verify_signed_json(json_object, server_name, verify_key)
-    except SignatureVerifyException as e:
-        logger.debug(
-            "Error verifying signature for %s:%s:%s with key %s: %s",
-            server_name,
-            verify_key.alg,
-            verify_key.version,
-            encode_verify_key_base64(verify_key),
-            str(e),
-        )
-        raise SynapseError(
-            401,
-            "Invalid signature for server %s with key %s:%s: %s"
-            % (server_name, verify_key.alg, verify_key.version, str(e)),
-            Codes.UNAUTHORIZED,
-        )
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 59e0a434dc..5756fcb551 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -37,6 +37,7 @@ from synapse.http.servlet import (
 )
 from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
+    SynapseTags,
     start_active_span,
     start_active_span_from_request,
     tags,
@@ -151,7 +152,9 @@ class Authenticator:
             )
 
         await self.keyring.verify_json_for_server(
-            origin, json_request, now, "Incoming request"
+            origin,
+            json_request,
+            now,
         )
 
         logger.debug("Request from %s", origin)
@@ -314,7 +317,7 @@ class BaseFederationServlet:
                 raise
 
             request_tags = {
-                "request_id": request.get_request_id(),
+                SynapseTags.REQUEST_ID: request.get_request_id(),
                 tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
                 tags.HTTP_METHOD: request.get_method(),
                 tags.HTTP_URL: request.get_redacted_uri(),
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index d2fc8be5f5..ff8372c4e9 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -108,7 +108,9 @@ class GroupAttestationSigning:
 
         assert server_name is not None
         await self.keyring.verify_json_for_server(
-            server_name, attestation, now, "Group attestation"
+            server_name,
+            attestation,
+            now,
         )
 
     def create_attestation(self, group_id: str, user_id: str) -> JsonDict:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bf11315251..49ed7cabcc 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -577,7 +577,9 @@ class FederationHandler(BaseHandler):
 
         # Fetch the state events from the DB, and check we have the auth events.
         event_map = await self.store.get_events(state_event_ids, allow_rejected=True)
-        auth_events_in_store = await self.store.have_seen_events(auth_event_ids)
+        auth_events_in_store = await self.store.have_seen_events(
+            room_id, auth_event_ids
+        )
 
         # Check for missing events. We handle state and auth event seperately,
         # as we want to pull the state from the DB, but we don't for the auth
@@ -610,7 +612,7 @@ class FederationHandler(BaseHandler):
 
             if missing_auth_events:
                 auth_events_in_store = await self.store.have_seen_events(
-                    missing_auth_events
+                    room_id, missing_auth_events
                 )
                 missing_auth_events.difference_update(auth_events_in_store)
 
@@ -710,7 +712,7 @@ class FederationHandler(BaseHandler):
 
         missing_auth_events = set(auth_event_ids) - fetched_events.keys()
         missing_auth_events.difference_update(
-            await self.store.have_seen_events(missing_auth_events)
+            await self.store.have_seen_events(room_id, missing_auth_events)
         )
         logger.debug("We are also missing %i auth events", len(missing_auth_events))
 
@@ -2475,7 +2477,7 @@ class FederationHandler(BaseHandler):
         #
         # we start by checking if they are in the store, and then try calling /event_auth/.
         if missing_auth:
-            have_events = await self.store.have_seen_events(missing_auth)
+            have_events = await self.store.have_seen_events(event.room_id, missing_auth)
             logger.debug("Events %s are in the store", have_events)
             missing_auth.difference_update(have_events)
 
@@ -2494,7 +2496,7 @@ class FederationHandler(BaseHandler):
                     return context
 
                 seen_remotes = await self.store.have_seen_events(
-                    [e.event_id for e in remote_auth_chain]
+                    event.room_id, [e.event_id for e in remote_auth_chain]
                 )
 
                 for e in remote_auth_chain:
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index abd9ddecca..046dba6fd8 100644
--- a/synapse/handlers/space_summary.py
+++ b/synapse/handlers/space_summary.py
@@ -26,7 +26,6 @@ from synapse.api.constants import (
     HistoryVisibility,
     Membership,
 )
-from synapse.api.errors import AuthError
 from synapse.events import EventBase
 from synapse.events.utils import format_event_for_client_v2
 from synapse.types import JsonDict
@@ -456,16 +455,16 @@ class SpaceSummaryHandler:
                     return True
 
             # Otherwise, check if they should be allowed access via membership in a space.
-            try:
-                await self._event_auth_handler.check_restricted_join_rules(
-                    state_ids, room_version, requester, member_event
+            if self._event_auth_handler.has_restricted_join_rules(
+                state_ids, room_version
+            ):
+                allowed_spaces = (
+                    await self._event_auth_handler.get_spaces_that_allow_join(state_ids)
                 )
-            except AuthError:
-                # The user doesn't have access due to spaces, but might have access
-                # another way. Keep trying.
-                pass
-            else:
-                return True
+                if await self._event_auth_handler.is_user_in_rooms(
+                    allowed_spaces, requester
+                ):
+                    return True
 
         # If this is a request over federation, check if the host is in the room or
         # is in one of the spaces specified via the join rules.
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 02ae36d2de..e607527ad1 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -464,7 +464,7 @@ class SyncHandler:
                 # ensure that we always include current state in the timeline
                 current_state_ids = frozenset()  # type: FrozenSet[str]
                 if any(e.is_state() for e in recents):
-                    current_state_ids_map = await self.state.get_current_state_ids(
+                    current_state_ids_map = await self.store.get_current_state_ids(
                         room_id
                     )
                     current_state_ids = frozenset(current_state_ids_map.values())
@@ -524,7 +524,7 @@ class SyncHandler:
                 # ensure that we always include current state in the timeline
                 current_state_ids = frozenset()
                 if any(e.is_state() for e in loaded_recents):
-                    current_state_ids_map = await self.state.get_current_state_ids(
+                    current_state_ids_map = await self.store.get_current_state_ids(
                         room_id
                     )
                     current_state_ids = frozenset(current_state_ids_map.values())
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 31897546a9..3f4f2411fc 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -15,6 +15,9 @@
 """ This module contains base REST classes for constructing REST servlets. """
 
 import logging
+from typing import Iterable, List, Optional, Union, overload
+
+from typing_extensions import Literal
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.util import json_decoder
@@ -107,12 +110,11 @@ def parse_boolean_from_args(args, name, default=None, required=False):
 
 def parse_string(
     request,
-    name,
-    default=None,
-    required=False,
-    allowed_values=None,
-    param_type="string",
-    encoding="ascii",
+    name: Union[bytes, str],
+    default: Optional[str] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: Optional[str] = "ascii",
 ):
     """
     Parse a string parameter from the request query string.
@@ -122,18 +124,17 @@ def parse_string(
 
     Args:
         request: the twisted HTTP request.
-        name (bytes|unicode): the name of the query parameter.
-        default (bytes|unicode|None): value to use if the parameter is absent,
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
             defaults to None. Must be bytes if encoding is None.
-        required (bool): whether to raise a 400 SynapseError if the
+        required: whether to raise a 400 SynapseError if the
             parameter is absent, defaults to False.
-        allowed_values (list[bytes|unicode]): List of allowed values for the
+        allowed_values: List of allowed values for the
             string, or None if any value is allowed, defaults to None. Must be
             the same type as name, if given.
-        encoding (str|None): The encoding to decode the string content with.
-
+        encoding : The encoding to decode the string content with.
     Returns:
-        bytes/unicode|None: A string value or the default. Unicode if encoding
+        A string value or the default. Unicode if encoding
         was given, bytes otherwise.
 
     Raises:
@@ -142,45 +143,105 @@ def parse_string(
             is not one of those allowed values.
     """
     return parse_string_from_args(
-        request.args, name, default, required, allowed_values, param_type, encoding
+        request.args, name, default, required, allowed_values, encoding
     )
 
 
-def parse_string_from_args(
-    args,
-    name,
-    default=None,
-    required=False,
-    allowed_values=None,
-    param_type="string",
-    encoding="ascii",
-):
+def _parse_string_value(
+    value: Union[str, bytes],
+    allowed_values: Optional[Iterable[str]],
+    name: str,
+    encoding: Optional[str],
+) -> Union[str, bytes]:
+    if encoding:
+        try:
+            value = value.decode(encoding)
+        except ValueError:
+            raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
+
+    if allowed_values is not None and value not in allowed_values:
+        message = "Query parameter %r must be one of [%s]" % (
+            name,
+            ", ".join(repr(v) for v in allowed_values),
+        )
+        raise SynapseError(400, message)
+    else:
+        return value
+
+
+@overload
+def parse_strings_from_args(
+    args: List[str],
+    name: Union[bytes, str],
+    default: Optional[List[str]] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: Literal[None] = None,
+) -> Optional[List[bytes]]:
+    ...
+
+
+@overload
+def parse_strings_from_args(
+    args: List[str],
+    name: Union[bytes, str],
+    default: Optional[List[str]] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> Optional[List[str]]:
+    ...
+
+
+def parse_strings_from_args(
+    args: List[str],
+    name: Union[bytes, str],
+    default: Optional[List[str]] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: Optional[str] = "ascii",
+) -> Optional[List[Union[bytes, str]]]:
+    """
+    Parse a string parameter from the request query string list.
+
+    If encoding is not None, the content of the query param will be
+    decoded to Unicode using the encoding, otherwise it will be encoded
+
+    Args:
+        args: the twisted HTTP request.args list.
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
+            defaults to None. Must be bytes if encoding is None.
+        required : whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+        allowed_values (list[bytes|unicode]): List of allowed values for the
+            string, or None if any value is allowed, defaults to None. Must be
+            the same type as name, if given.
+        encoding: The encoding to decode the string content with.
+
+    Returns:
+        A string value or the default. Unicode if encoding
+        was given, bytes otherwise.
+
+    Raises:
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present, must be one of a list of allowed values and
+            is not one of those allowed values.
+    """
 
     if not isinstance(name, bytes):
         name = name.encode("ascii")
 
     if name in args:
-        value = args[name][0]
-
-        if encoding:
-            try:
-                value = value.decode(encoding)
-            except ValueError:
-                raise SynapseError(
-                    400, "Query parameter %r must be %s" % (name, encoding)
-                )
-
-        if allowed_values is not None and value not in allowed_values:
-            message = "Query parameter %r must be one of [%s]" % (
-                name,
-                ", ".join(repr(v) for v in allowed_values),
-            )
-            raise SynapseError(400, message)
-        else:
-            return value
+        values = args[name]
+
+        return [
+            _parse_string_value(value, allowed_values, name=name, encoding=encoding)
+            for value in values
+        ]
     else:
         if required:
-            message = "Missing %s query parameter %r" % (param_type, name)
+            message = "Missing string query parameter %r" % (name)
             raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
         else:
 
@@ -190,6 +251,55 @@ def parse_string_from_args(
             return default
 
 
+def parse_string_from_args(
+    args: List[str],
+    name: Union[bytes, str],
+    default: Optional[str] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: Optional[str] = "ascii",
+) -> Optional[Union[bytes, str]]:
+    """
+    Parse the string parameter from the request query string list
+    and return the first result.
+
+    If encoding is not None, the content of the query param will be
+    decoded to Unicode using the encoding, otherwise it will be encoded
+
+    Args:
+        args: the twisted HTTP request.args list.
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
+            defaults to None. Must be bytes if encoding is None.
+        required: whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+        allowed_values: List of allowed values for the
+            string, or None if any value is allowed, defaults to None. Must be
+            the same type as name, if given.
+        encoding: The encoding to decode the string content with.
+
+    Returns:
+        A string value or the default. Unicode if encoding
+        was given, bytes otherwise.
+
+    Raises:
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present, must be one of a list of allowed values and
+            is not one of those allowed values.
+    """
+
+    strings = parse_strings_from_args(
+        args,
+        name,
+        default=[default],
+        required=required,
+        allowed_values=allowed_values,
+        encoding=encoding,
+    )
+
+    return strings[0]
+
+
 def parse_json_value_from_request(request, allow_empty_body=False):
     """Parse a JSON value from the body of a twisted HTTP request.
 
@@ -215,7 +325,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
     try:
         content = json_decoder.decode(content_bytes.decode("utf-8"))
     except Exception as e:
-        logger.warning("Unable to parse JSON: %s", e)
+        logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes)
         raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
 
     return content
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index fba2fa3904..f64845b80c 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -265,6 +265,12 @@ class SynapseTags:
     # Whether the sync response has new data to be returned to the client.
     SYNC_RESULT = "sync.new_data"
 
+    # incoming HTTP request ID  (as written in the logs)
+    REQUEST_ID = "request_id"
+
+    # HTTP request tag (used to distinguish full vs incremental syncs, etc)
+    REQUEST_TAG = "request_tag"
+
 
 # Block everything by default
 # A regex which matches the server_names to expose traces for.
@@ -588,7 +594,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
 
     span = opentracing.tracer.active_span
     carrier = {}  # type: Dict[str, str]
-    opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
+    opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
 
     for key, value in carrier.items():
         headers.addRawHeaders(key, value)
@@ -625,7 +631,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
     span = opentracing.tracer.active_span
 
     carrier = {}  # type: Dict[str, str]
-    opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
+    opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
 
     for key, value in carrier.items():
         headers[key.encode()] = [value.encode()]
@@ -659,7 +665,7 @@ def inject_active_span_text_map(carrier, destination, check_destination=True):
         return
 
     opentracing.tracer.inject(
-        opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+        opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
     )
 
 
@@ -681,7 +687,7 @@ def get_active_span_text_map(destination=None):
 
     carrier = {}  # type: Dict[str, str]
     opentracing.tracer.inject(
-        opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+        opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
     )
 
     return carrier
@@ -696,7 +702,7 @@ def active_span_context_as_string():
     carrier = {}  # type: Dict[str, str]
     if opentracing:
         opentracing.tracer.inject(
-            opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+            opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
         )
     return json_encoder.encode(carrier)
 
@@ -824,7 +830,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
         return
 
     request_tags = {
-        "request_id": request.get_request_id(),
+        SynapseTags.REQUEST_ID: request.get_request_id(),
         tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
         tags.HTTP_METHOD: request.get_method(),
         tags.HTTP_URL: request.get_redacted_uri(),
@@ -833,9 +839,9 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
 
     request_name = request.request_metrics.name
     if extract_context:
-        scope = start_active_span_from_request(request, request_name, tags=request_tags)
+        scope = start_active_span_from_request(request, request_name)
     else:
-        scope = start_active_span(request_name, tags=request_tags)
+        scope = start_active_span(request_name)
 
     with scope:
         try:
@@ -845,4 +851,11 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
             # with JsonResource).
             scope.span.set_operation_name(request.request_metrics.name)
 
-            scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)
+            # set the tags *after* the servlet completes, in case it decided to
+            # prioritise the span (tags will get dropped on unprioritised spans)
+            request_tags[
+                SynapseTags.REQUEST_TAG
+            ] = request.request_metrics.start_context.tag
+
+            for k, v in request_tags.items():
+                scope.span.set_tag(k, v)
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 714caf84c3..0d6d643d35 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -22,7 +22,11 @@ from prometheus_client.core import REGISTRY, Counter, Gauge
 from twisted.internet import defer
 
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
-from synapse.logging.opentracing import noop_context_manager, start_active_span
+from synapse.logging.opentracing import (
+    SynapseTags,
+    noop_context_manager,
+    start_active_span,
+)
 from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
@@ -202,7 +206,9 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
             try:
                 ctx = noop_context_manager()
                 if bg_start_span:
-                    ctx = start_active_span(desc, tags={"request_id": str(context)})
+                    ctx = start_active_span(
+                        desc, tags={SynapseTags.REQUEST_ID: str(context)}
+                    )
                 with ctx:
                     return await maybe_awaitable(func(*args, **kwargs))
             except Exception:
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index d6d55893af..70286b0ff7 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -1061,15 +1061,15 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
     RoomTypingRestServlet(hs).register(http_server)
     RoomEventContextServlet(hs).register(http_server)
     RoomSpaceSummaryRestServlet(hs).register(http_server)
+    RoomEventServlet(hs).register(http_server)
+    JoinedRoomsRestServlet(hs).register(http_server)
+    RoomAliasListServlet(hs).register(http_server)
+    SearchRestServlet(hs).register(http_server)
 
     # Some servlets only get registered for the main process.
     if not is_worker:
         RoomCreateRestServlet(hs).register(http_server)
         RoomForgetRestServlet(hs).register(http_server)
-        SearchRestServlet(hs).register(http_server)
-        JoinedRoomsRestServlet(hs).register(http_server)
-        RoomEventServlet(hs).register(http_server)
-        RoomAliasListServlet(hs).register(http_server)
 
 
 def register_deprecated_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index 2c169abbf3..07ea39a8a3 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -16,11 +16,7 @@ import logging
 from http import HTTPStatus
 
 from synapse.api.errors import Codes, SynapseError
-from synapse.http.servlet import (
-    RestServlet,
-    assert_params_in_dict,
-    parse_json_object_from_request,
-)
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
 
 from ._base import client_patterns
 
@@ -42,15 +38,14 @@ class ReportEventRestServlet(RestServlet):
         user_id = requester.user.to_string()
 
         body = parse_json_object_from_request(request)
-        assert_params_in_dict(body, ("reason", "score"))
 
-        if not isinstance(body["reason"], str):
+        if not isinstance(body.get("reason", ""), str):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "Param 'reason' must be a string",
                 Codes.BAD_JSON,
             )
-        if not isinstance(body["score"], int):
+        if not isinstance(body.get("score", 0), int):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "Param 'score' must be an integer",
@@ -61,7 +56,7 @@ class ReportEventRestServlet(RestServlet):
             room_id=room_id,
             event_id=event_id,
             user_id=user_id,
-            reason=body["reason"],
+            reason=body.get("reason"),
             content=body,
             received_ts=self.clock.time_msec(),
         )
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index aba1734a55..d56a1ae482 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -22,6 +22,7 @@ from synapse.crypto.keyring import ServerKeyFetcher
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_integer, parse_json_object_from_request
 from synapse.util import json_decoder
+from synapse.util.async_helpers import yieldable_gather_results
 
 logger = logging.getLogger(__name__)
 
@@ -210,7 +211,13 @@ class RemoteKey(DirectServeJsonResource):
         # If there is a cache miss, request the missing keys, then recurse (and
         # ensure the result is sent).
         if cache_misses and query_remote_on_cache_miss:
-            await self.fetcher.get_keys(cache_misses)
+            await yieldable_gather_results(
+                lambda t: self.fetcher.get_keys(*t),
+                (
+                    (server_name, list(keys), 0)
+                    for server_name, keys in cache_misses.items()
+                ),
+            )
             await self.query_keys(request, query, query_remote_on_cache_miss=False)
         else:
             signed_keys = []
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index f7872501a0..c57ae5ef15 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -168,6 +168,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         backfilled,
     ):
         self._invalidate_get_event_cache(event_id)
+        self.have_seen_event.invalidate((room_id, event_id))
 
         self.get_latest_event_ids_in_room.invalidate((room_id,))
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 6963bbf7f4..403a5ddaba 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -22,6 +22,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Set,
     Tuple,
     overload,
 )
@@ -55,7 +56,7 @@ from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
@@ -1045,32 +1046,74 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {r["event_id"] for r in rows}
 
-    async def have_seen_events(self, event_ids):
+    async def have_seen_events(
+        self, room_id: str, event_ids: Iterable[str]
+    ) -> Set[str]:
         """Given a list of event ids, check if we have already processed them.
 
+        The room_id is only used to structure the cache (so that it can later be
+        invalidated by room_id) - there is no guarantee that the events are actually
+        in the room in question.
+
         Args:
-            event_ids (iterable[str]):
+            room_id: Room we are polling
+            event_ids: events we are looking for
 
         Returns:
             set[str]: The events we have already seen.
         """
+        res = await self._have_seen_events_dict(
+            (room_id, event_id) for event_id in event_ids
+        )
+        return {eid for ((_rid, eid), have_event) in res.items() if have_event}
+
+    @cachedList("have_seen_event", "keys")
+    async def _have_seen_events_dict(
+        self, keys: Iterable[Tuple[str, str]]
+    ) -> Dict[Tuple[str, str], bool]:
+        """Helper for have_seen_events
+
+        Returns:
+             a dict {(room_id, event_id)-> bool}
+        """
         # if the event cache contains the event, obviously we've seen it.
-        results = {x for x in event_ids if self._get_event_cache.contains(x)}
 
-        def have_seen_events_txn(txn, chunk):
-            sql = "SELECT event_id FROM events as e WHERE "
+        cache_results = {
+            (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,))
+        }
+        results = {x: True for x in cache_results}
+
+        def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
+            # we deliberately do *not* query the database for room_id, to make the
+            # query an index-only lookup on `events_event_id_key`.
+            #
+            # We therefore pull the events from the database into a set...
+
+            sql = "SELECT event_id FROM events AS e WHERE "
             clause, args = make_in_list_sql_clause(
-                txn.database_engine, "e.event_id", chunk
+                txn.database_engine, "e.event_id", [eid for (_rid, eid) in chunk]
             )
             txn.execute(sql + clause, args)
-            results.update(row[0] for row in txn)
+            found_events = {eid for eid, in txn}
 
-        for chunk in batch_iter((x for x in event_ids if x not in results), 100):
+            # ... and then we can update the results for each row in the batch
+            results.update({(rid, eid): (eid in found_events) for (rid, eid) in chunk})
+
+        # each batch requires its own index scan, so we make the batches as big as
+        # possible.
+        for chunk in batch_iter((k for k in keys if k not in cache_results), 500):
             await self.db_pool.runInteraction(
                 "have_seen_events", have_seen_events_txn, chunk
             )
+
         return results
 
+    @cached(max_entries=100000, tree=True)
+    async def have_seen_event(self, room_id: str, event_id: str):
+        # this only exists for the benefit of the @cachedList descriptor on
+        # _have_seen_events_dict
+        raise NotImplementedError()
+
     def _get_current_state_event_counts_txn(self, txn, room_id):
         """
         See get_current_state_event_counts.
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 8f83748b5e..7fb7780d0f 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -16,14 +16,14 @@ import logging
 from typing import Any, List, Set, Tuple
 
 from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore
+from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.state import StateGroupWorkerStore
 from synapse.types import RoomStreamToken
 
 logger = logging.getLogger(__name__)
 
 
-class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
+class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
     async def purge_history(
         self, room_id: str, token: str, delete_local_events: bool
     ) -> Set[int]:
@@ -203,8 +203,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
             "DELETE FROM event_to_state_groups "
             "WHERE event_id IN (SELECT event_id from events_to_purge)"
         )
-        for event_id, _ in event_rows:
-            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
 
         # Delete all remote non-state events
         for table in (
@@ -283,6 +281,20 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         # so make sure to keep this actually last.
         txn.execute("DROP TABLE events_to_purge")
 
+        for event_id, should_delete in event_rows:
+            self._invalidate_cache_and_stream(
+                txn, self._get_state_group_for_event, (event_id,)
+            )
+
+            # XXX: This is racy, since have_seen_events could be called between the
+            #    transaction completing and the invalidation running. On the other hand,
+            #    that's no different to calling `have_seen_events` just before the
+            #    event is deleted from the database.
+            if should_delete:
+                self._invalidate_cache_and_stream(
+                    txn, self.have_seen_event, (room_id, event_id)
+                )
+
         logger.info("[purge] done")
 
         return referenced_state_groups
@@ -422,7 +434,11 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         #       index on them. In any case we should be clearing out 'stream' tables
         #       periodically anyway (#5888)
 
-        # TODO: we could probably usefully do a bunch of cache invalidation here
+        # TODO: we could probably usefully do a bunch more cache invalidation here
+
+        # XXX: as with purge_history, this is racy, but no worse than other races
+        #   that already exist.
+        self._invalidate_cache_and_stream(txn, self.have_seen_event, (room_id,))
 
         logger.info("[purge] done")
 
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 5f38634f48..0cf450f81d 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1498,7 +1498,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         room_id: str,
         event_id: str,
         user_id: str,
-        reason: str,
+        reason: Optional[str],
         content: JsonDict,
         received_ts: int,
     ) -> None:
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 2775dfd880..745c295d3b 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import time
+from typing import Dict, List
 from unittest.mock import Mock
 
 import attr
@@ -21,7 +22,6 @@ import signedjson.sign
 from nacl.signing import SigningKey
 from signedjson.key import encode_verify_key_base64, get_verify_key
 
-from twisted.internet import defer
 from twisted.internet.defer import Deferred, ensureDeferred
 
 from synapse.api.errors import SynapseError
@@ -92,23 +92,23 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         # deferred completes.
         first_lookup_deferred = Deferred()
 
-        async def first_lookup_fetch(keys_to_fetch):
-            self.assertEquals(current_context().request.id, "context_11")
-            self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
+        async def first_lookup_fetch(
+            server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+        ) -> Dict[str, FetchKeyResult]:
+            # self.assertEquals(current_context().request.id, "context_11")
+            self.assertEqual(server_name, "server10")
+            self.assertEqual(key_ids, [get_key_id(key1)])
+            self.assertEqual(minimum_valid_until_ts, 0)
 
             await make_deferred_yieldable(first_lookup_deferred)
-            return {
-                "server10": {
-                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
-                }
-            }
+            return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)}
 
         mock_fetcher.get_keys.side_effect = first_lookup_fetch
 
         async def first_lookup():
             with LoggingContext("context_11", request=FakeRequest("context_11")):
                 res_deferreds = kr.verify_json_objects_for_server(
-                    [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
+                    [("server10", json1, 0), ("server11", {}, 0)]
                 )
 
                 # the unsigned json should be rejected pretty quickly
@@ -126,18 +126,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
 
         d0 = ensureDeferred(first_lookup())
 
+        self.pump()
+
         mock_fetcher.get_keys.assert_called_once()
 
         # a second request for a server with outstanding requests
         # should block rather than start a second call
 
-        async def second_lookup_fetch(keys_to_fetch):
-            self.assertEquals(current_context().request.id, "context_12")
-            return {
-                "server10": {
-                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
-                }
-            }
+        async def second_lookup_fetch(
+            server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+        ) -> Dict[str, FetchKeyResult]:
+            # self.assertEquals(current_context().request.id, "context_12")
+            return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)}
 
         mock_fetcher.get_keys.reset_mock()
         mock_fetcher.get_keys.side_effect = second_lookup_fetch
@@ -146,7 +146,13 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         async def second_lookup():
             with LoggingContext("context_12", request=FakeRequest("context_12")):
                 res_deferreds_2 = kr.verify_json_objects_for_server(
-                    [("server10", json1, 0, "test")]
+                    [
+                        (
+                            "server10",
+                            json1,
+                            0,
+                        )
+                    ]
                 )
                 res_deferreds_2[0].addBoth(self.check_context, None)
                 second_lookup_state[0] = 1
@@ -183,11 +189,11 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         signedjson.sign.sign_json(json1, "server9", key1)
 
         # should fail immediately on an unsigned object
-        d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
+        d = kr.verify_json_for_server("server9", {}, 0)
         self.get_failure(d, SynapseError)
 
         # should succeed on a signed object
-        d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
+        d = kr.verify_json_for_server("server9", json1, 500)
         # self.assertFalse(d.called)
         self.get_success(d)
 
@@ -214,24 +220,24 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         signedjson.sign.sign_json(json1, "server9", key1)
 
         # should fail immediately on an unsigned object
-        d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
+        d = kr.verify_json_for_server("server9", {}, 0)
         self.get_failure(d, SynapseError)
 
         # should fail on a signed object with a non-zero minimum_valid_until_ms,
         # as it tries to refetch the keys and fails.
-        d = _verify_json_for_server(
-            kr, "server9", json1, 500, "test signed non-zero min"
-        )
+        d = kr.verify_json_for_server("server9", json1, 500)
         self.get_failure(d, SynapseError)
 
         # We expect the keyring tried to refetch the key once.
         mock_fetcher.get_keys.assert_called_once_with(
-            {"server9": {get_key_id(key1): 500}}
+            "server9", [get_key_id(key1)], 500
         )
 
         # should succeed on a signed object with a 0 minimum_valid_until_ms
-        d = _verify_json_for_server(
-            kr, "server9", json1, 0, "test signed with zero min"
+        d = kr.verify_json_for_server(
+            "server9",
+            json1,
+            0,
         )
         self.get_success(d)
 
@@ -239,15 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         """Two requests for the same key should be deduped."""
         key1 = signedjson.key.generate_signing_key(1)
 
-        async def get_keys(keys_to_fetch):
+        async def get_keys(
+            server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+        ) -> Dict[str, FetchKeyResult]:
             # there should only be one request object (with the max validity)
-            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+            self.assertEqual(server_name, "server1")
+            self.assertEqual(key_ids, [get_key_id(key1)])
+            self.assertEqual(minimum_valid_until_ts, 1500)
 
-            return {
-                "server1": {
-                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
-                }
-            }
+            return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)}
 
         mock_fetcher = Mock()
         mock_fetcher.get_keys = Mock(side_effect=get_keys)
@@ -259,7 +265,14 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         # the first request should succeed; the second should fail because the key
         # has expired
         results = kr.verify_json_objects_for_server(
-            [("server1", json1, 500, "test1"), ("server1", json1, 1500, "test2")]
+            [
+                (
+                    "server1",
+                    json1,
+                    500,
+                ),
+                ("server1", json1, 1500),
+            ]
         )
         self.assertEqual(len(results), 2)
         self.get_success(results[0])
@@ -274,19 +287,21 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         """If the first fetcher cannot provide a recent enough key, we fall back"""
         key1 = signedjson.key.generate_signing_key(1)
 
-        async def get_keys1(keys_to_fetch):
-            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
-            return {
-                "server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
-            }
-
-        async def get_keys2(keys_to_fetch):
-            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
-            return {
-                "server1": {
-                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
-                }
-            }
+        async def get_keys1(
+            server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+        ) -> Dict[str, FetchKeyResult]:
+            self.assertEqual(server_name, "server1")
+            self.assertEqual(key_ids, [get_key_id(key1)])
+            self.assertEqual(minimum_valid_until_ts, 1500)
+            return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
+
+        async def get_keys2(
+            server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+        ) -> Dict[str, FetchKeyResult]:
+            self.assertEqual(server_name, "server1")
+            self.assertEqual(key_ids, [get_key_id(key1)])
+            self.assertEqual(minimum_valid_until_ts, 1500)
+            return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)}
 
         mock_fetcher1 = Mock()
         mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
@@ -298,7 +313,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         signedjson.sign.sign_json(json1, "server1", key1)
 
         results = kr.verify_json_objects_for_server(
-            [("server1", json1, 1200, "test1"), ("server1", json1, 1500, "test2")]
+            [
+                (
+                    "server1",
+                    json1,
+                    1200,
+                ),
+                (
+                    "server1",
+                    json1,
+                    1500,
+                ),
+            ]
         )
         self.assertEqual(len(results), 2)
         self.get_success(results[0])
@@ -349,9 +375,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         self.http_client.get_json.side_effect = get_json
 
-        keys_to_fetch = {SERVER_NAME: {"key1": 0}}
-        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
-        k = keys[SERVER_NAME][testverifykey_id]
+        keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
+        k = keys[testverifykey_id]
         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
         self.assertEqual(k.verify_key, testverifykey)
         self.assertEqual(k.verify_key.alg, "ed25519")
@@ -378,7 +403,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
         # change the server name: the result should be ignored
         response["server_name"] = "OTHER_SERVER"
 
-        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
+        keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
         self.assertEqual(keys, {})
 
 
@@ -465,10 +490,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
 
-        keys_to_fetch = {SERVER_NAME: {"key1": 0}}
-        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
-        self.assertIn(SERVER_NAME, keys)
-        k = keys[SERVER_NAME][testverifykey_id]
+        keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
+        self.assertIn(testverifykey_id, keys)
+        k = keys[testverifykey_id]
         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
         self.assertEqual(k.verify_key, testverifykey)
         self.assertEqual(k.verify_key.alg, "ed25519")
@@ -515,10 +539,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
 
-        keys_to_fetch = {SERVER_NAME: {"key1": 0}}
-        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
-        self.assertIn(SERVER_NAME, keys)
-        k = keys[SERVER_NAME][testverifykey_id]
+        keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
+        self.assertIn(testverifykey_id, keys)
+        k = keys[testverifykey_id]
         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
         self.assertEqual(k.verify_key, testverifykey)
         self.assertEqual(k.verify_key.alg, "ed25519")
@@ -559,14 +582,13 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         def get_key_from_perspectives(response):
             fetcher = PerspectivesKeyFetcher(self.hs)
-            keys_to_fetch = {SERVER_NAME: {"key1": 0}}
             self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
-            return self.get_success(fetcher.get_keys(keys_to_fetch))
+            return self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
 
         # start with a valid response so we can check we are testing the right thing
         response = build_response()
         keys = get_key_from_perspectives(response)
-        k = keys[SERVER_NAME][testverifykey_id]
+        k = keys[testverifykey_id]
         self.assertEqual(k.verify_key, testverifykey)
 
         # remove the perspectives server's signature
@@ -585,23 +607,3 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 def get_key_id(key):
     """Get the matrix ID tag for a given SigningKey or VerifyKey"""
     return "%s:%s" % (key.alg, key.version)
-
-
-@defer.inlineCallbacks
-def run_in_context(f, *args, **kwargs):
-    with LoggingContext("testctx"):
-        rv = yield f(*args, **kwargs)
-    return rv
-
-
-def _verify_json_for_server(kr, *args):
-    """thin wrapper around verify_json_for_server which makes sure it is wrapped
-    with the patched defer.inlineCallbacks.
-    """
-
-    @defer.inlineCallbacks
-    def v():
-        rv1 = yield kr.verify_json_for_server(*args)
-        return rv1
-
-    return run_in_context(v)
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 29341bc6e9..f15d1cf6f7 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -64,7 +64,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
                 user_tok=self.admin_user_tok,
             )
         for _ in range(5):
-            self._create_event_and_report(
+            self._create_event_and_report_without_parameters(
                 room_id=self.room_id2,
                 user_tok=self.admin_user_tok,
             )
@@ -378,6 +378,19 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
+    def _create_event_and_report_without_parameters(self, room_id, user_tok):
+        """Create and report an event, but omit reason and score"""
+        resp = self.helper.send(room_id, tok=user_tok)
+        event_id = resp["event_id"]
+
+        channel = self.make_request(
+            "POST",
+            "rooms/%s/report/%s" % (room_id, event_id),
+            json.dumps({}),
+            access_token=user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
     def _check_fields(self, content):
         """Checks that all attributes are present in an event report"""
         for c in content:
diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py
new file mode 100644
index 0000000000..1ec6b05e5b
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_report_event.py
@@ -0,0 +1,83 @@
+# Copyright 2021 Callum Brown
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+import synapse.rest.admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import report_event
+
+from tests import unittest
+
+
+class ReportEventTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+        report_event.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_tok = self.login("user", "pass")
+
+        self.room_id = self.helper.create_room_as(
+            self.other_user, tok=self.other_user_tok, is_public=True
+        )
+        self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok)
+        resp = self.helper.send(self.room_id, tok=self.admin_user_tok)
+        self.event_id = resp["event_id"]
+        self.report_path = "rooms/{}/report/{}".format(self.room_id, self.event_id)
+
+    def test_reason_str_and_score_int(self):
+        data = {"reason": "this makes me sad", "score": -100}
+        self._assert_status(200, data)
+
+    def test_no_reason(self):
+        data = {"score": 0}
+        self._assert_status(200, data)
+
+    def test_no_score(self):
+        data = {"reason": "this makes me sad"}
+        self._assert_status(200, data)
+
+    def test_no_reason_and_no_score(self):
+        data = {}
+        self._assert_status(200, data)
+
+    def test_reason_int_and_score_str(self):
+        data = {"reason": 10, "score": "string"}
+        self._assert_status(400, data)
+
+    def test_reason_zero_and_score_blank(self):
+        data = {"reason": 0, "score": ""}
+        self._assert_status(400, data)
+
+    def test_reason_and_score_null(self):
+        data = {"reason": None, "score": None}
+        self._assert_status(400, data)
+
+    def _assert_status(self, response_status, data):
+        channel = self.make_request(
+            "POST",
+            self.report_path,
+            json.dumps(data),
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(
+            response_status, int(channel.result["code"]), msg=channel.result["body"]
+        )
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 3b275bc23b..a75c0ea3f0 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -208,10 +208,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
         keyid = "ed25519:%s" % (testkey.version,)
 
         fetcher = PerspectivesKeyFetcher(self.hs2)
-        d = fetcher.get_keys({"targetserver": {keyid: 1000}})
+        d = fetcher.get_keys("targetserver", [keyid], 1000)
         res = self.get_success(d)
-        self.assertIn("targetserver", res)
-        keyres = res["targetserver"][keyid]
+        self.assertIn(keyid, res)
+        keyres = res[keyid]
         assert isinstance(keyres, FetchKeyResult)
         self.assertEqual(
             signedjson.key.encode_verify_key_base64(keyres.verify_key),
@@ -230,10 +230,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
         keyid = "ed25519:%s" % (testkey.version,)
 
         fetcher = PerspectivesKeyFetcher(self.hs2)
-        d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}})
+        d = fetcher.get_keys(self.hs.hostname, [keyid], 1000)
         res = self.get_success(d)
-        self.assertIn(self.hs.hostname, res)
-        keyres = res[self.hs.hostname][keyid]
+        self.assertIn(keyid, res)
+        keyres = res[keyid]
         assert isinstance(keyres, FetchKeyResult)
         self.assertEqual(
             signedjson.key.encode_verify_key_base64(keyres.verify_key),
@@ -247,10 +247,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
         keyid = "ed25519:%s" % (self.hs_signing_key.version,)
 
         fetcher = PerspectivesKeyFetcher(self.hs2)
-        d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}})
+        d = fetcher.get_keys(self.hs.hostname, [keyid], 1000)
         res = self.get_success(d)
-        self.assertIn(self.hs.hostname, res)
-        keyres = res[self.hs.hostname][keyid]
+        self.assertIn(keyid, res)
+        keyres = res[keyid]
         assert isinstance(keyres, FetchKeyResult)
         self.assertEqual(
             signedjson.key.encode_verify_key_base64(keyres.verify_key),
diff --git a/tests/storage/databases/__init__.py b/tests/storage/databases/__init__.py
new file mode 100644
index 0000000000..c24c7ecd92
--- /dev/null
+++ b/tests/storage/databases/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/storage/databases/main/__init__.py b/tests/storage/databases/main/__init__.py
new file mode 100644
index 0000000000..c24c7ecd92
--- /dev/null
+++ b/tests/storage/databases/main/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
new file mode 100644
index 0000000000..932970fd9a
--- /dev/null
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -0,0 +1,96 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+
+from synapse.logging.context import LoggingContext
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+
+from tests import unittest
+
+
+class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.store: EventsWorkerStore = hs.get_datastore()
+
+        # insert some test data
+        for rid in ("room1", "room2"):
+            self.get_success(
+                self.store.db_pool.simple_insert(
+                    "rooms",
+                    {"room_id": rid, "room_version": 4},
+                )
+            )
+
+        for idx, (rid, eid) in enumerate(
+            (
+                ("room1", "event10"),
+                ("room1", "event11"),
+                ("room1", "event12"),
+                ("room2", "event20"),
+            )
+        ):
+            self.get_success(
+                self.store.db_pool.simple_insert(
+                    "events",
+                    {
+                        "event_id": eid,
+                        "room_id": rid,
+                        "topological_ordering": idx,
+                        "stream_ordering": idx,
+                        "type": "test",
+                        "processed": True,
+                        "outlier": False,
+                    },
+                )
+            )
+            self.get_success(
+                self.store.db_pool.simple_insert(
+                    "event_json",
+                    {
+                        "event_id": eid,
+                        "room_id": rid,
+                        "json": json.dumps({"type": "test", "room_id": rid}),
+                        "internal_metadata": "{}",
+                        "format_version": 3,
+                    },
+                )
+            )
+
+    def test_simple(self):
+        with LoggingContext(name="test") as ctx:
+            res = self.get_success(
+                self.store.have_seen_events("room1", ["event10", "event19"])
+            )
+            self.assertEquals(res, {"event10"})
+
+            # that should result in a single db query
+            self.assertEquals(ctx.get_resource_usage().db_txn_count, 1)
+
+        # a second lookup of the same events should cause no queries
+        with LoggingContext(name="test") as ctx:
+            res = self.get_success(
+                self.store.have_seen_events("room1", ["event10", "event19"])
+            )
+            self.assertEquals(res, {"event10"})
+            self.assertEquals(ctx.get_resource_usage().db_txn_count, 0)
+
+    def test_query_via_event_cache(self):
+        # fetch an event into the event cache
+        self.get_success(self.store.get_event("event10"))
+
+        # looking it up should now cause no db hits
+        with LoggingContext(name="test") as ctx:
+            res = self.get_success(self.store.have_seen_events("room1", ["event10"]))
+            self.assertEquals(res, {"event10"})
+            self.assertEquals(ctx.get_resource_usage().db_txn_count, 0)
diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py
index edf29e5b96..07be57d72c 100644
--- a/tests/util/test_batching_queue.py
+++ b/tests/util/test_batching_queue.py
@@ -45,37 +45,32 @@ class BatchingQueueTestCase(TestCase):
         self._pending_calls.append((values, d))
         return await make_deferred_yieldable(d)
 
+    def _get_sample_with_name(self, metric, name) -> int:
+        """For a prometheus metric get the value of the sample that has a
+        matching "name" label.
+        """
+        for sample in metric.collect()[0].samples:
+            if sample.labels.get("name") == name:
+                return sample.value
+
+        self.fail("Found no matching sample")
+
     def _assert_metrics(self, queued, keys, in_flight):
         """Assert that the metrics are correct"""
 
-        self.assertEqual(len(number_queued.collect()), 1)
-        self.assertEqual(len(number_queued.collect()[0].samples), 1)
+        sample = self._get_sample_with_name(number_queued, self.queue._name)
         self.assertEqual(
-            number_queued.collect()[0].samples[0].labels,
-            {"name": self.queue._name},
-        )
-        self.assertEqual(
-            number_queued.collect()[0].samples[0].value,
+            sample,
             queued,
             "number_queued",
         )
 
-        self.assertEqual(len(number_of_keys.collect()), 1)
-        self.assertEqual(len(number_of_keys.collect()[0].samples), 1)
-        self.assertEqual(
-            number_queued.collect()[0].samples[0].labels, {"name": self.queue._name}
-        )
-        self.assertEqual(
-            number_of_keys.collect()[0].samples[0].value, keys, "number_of_keys"
-        )
+        sample = self._get_sample_with_name(number_of_keys, self.queue._name)
+        self.assertEqual(sample, keys, "number_of_keys")
 
-        self.assertEqual(len(number_in_flight.collect()), 1)
-        self.assertEqual(len(number_in_flight.collect()[0].samples), 1)
-        self.assertEqual(
-            number_queued.collect()[0].samples[0].labels, {"name": self.queue._name}
-        )
+        sample = self._get_sample_with_name(number_in_flight, self.queue._name)
         self.assertEqual(
-            number_in_flight.collect()[0].samples[0].value,
+            sample,
             in_flight,
             "number_in_flight",
         )