-rw-r--r--  changelog.d/9882.misc                              |   1
-rw-r--r--  changelog.d/9902.feature                           |   1
-rw-r--r--  changelog.d/9904.misc                              |   1
-rw-r--r--  changelog.d/9910.bugfix                            |   1
-rw-r--r--  changelog.d/9910.feature (renamed from changelog.d/9910.misc) |   0
-rw-r--r--  changelog.d/9911.doc                               |   1
-rw-r--r--  docs/sample_config.yaml                            |   9
-rw-r--r--  synapse/app/generic_worker.py                      |   3
-rw-r--r--  synapse/app/homeserver.py                          |   3
-rw-r--r--  synapse/config/database.py                         |   1
-rw-r--r--  synapse/config/server.py                           |   9
-rw-r--r--  synapse/crypto/keyring.py                          | 618
-rw-r--r--  synapse/federation/federation_base.py              |  17
-rw-r--r--  synapse/federation/federation_client.py            |  69
-rw-r--r--  synapse/federation/transport/client.py             |  10
-rw-r--r--  synapse/handlers/directory.py                      |   4
-rw-r--r--  synapse/handlers/events.py                         |   2
-rw-r--r--  synapse/handlers/federation.py                     |  21
-rw-r--r--  synapse/handlers/message.py                        |   2
-rw-r--r--  synapse/handlers/room.py                           |   2
-rw-r--r--  synapse/handlers/sync.py                           |   6
-rw-r--r--  synapse/http/matrixfederationclient.py             |  22
-rw-r--r--  synapse/metrics/__init__.py                        | 179
-rw-r--r--  synapse/replication/tcp/external_cache.py          |  36
-rw-r--r--  synapse/rest/key/v2/remote_key_resource.py         |   9
-rw-r--r--  synapse/state/__init__.py                          |  10
-rw-r--r--  synapse/storage/_base.py                           |   1
-rw-r--r--  synapse/storage/databases/main/roommember.py       |   8
-rw-r--r--  synapse/storage/databases/main/user_directory.py   |   4
29 files changed, 583 insertions, 467 deletions
diff --git a/changelog.d/9882.misc b/changelog.d/9882.misc
new file mode 100644
index 0000000000..facfa31f38
--- /dev/null
+++ b/changelog.d/9882.misc
@@ -0,0 +1 @@
+Export jemalloc stats to Prometheus if it is being used.
diff --git a/changelog.d/9902.feature b/changelog.d/9902.feature
new file mode 100644
index 0000000000..4d9f324d4e
--- /dev/null
+++ b/changelog.d/9902.feature
@@ -0,0 +1 @@
+Add limits to how often Synapse will GC, ensuring that large servers do not end up GC thrashing if `gc_thresholds` has not been correctly set.
diff --git a/changelog.d/9904.misc b/changelog.d/9904.misc
new file mode 100644
index 0000000000..3db1e625ae
--- /dev/null
+++ b/changelog.d/9904.misc
@@ -0,0 +1 @@
+Record the response time of external cache requests.
diff --git a/changelog.d/9910.bugfix b/changelog.d/9910.bugfix
new file mode 100644
index 0000000000..06d523fd46
--- /dev/null
+++ b/changelog.d/9910.bugfix
@@ -0,0 +1 @@
+Fix bug where user directory could get out of sync if room visibility and membership changed in quick succession.
diff --git a/changelog.d/9910.misc b/changelog.d/9910.feature
index 54165cce18..54165cce18 100644
--- a/changelog.d/9910.misc
+++ b/changelog.d/9910.feature
diff --git a/changelog.d/9911.doc b/changelog.d/9911.doc
new file mode 100644
index 0000000000..f7fd9f1ba9
--- /dev/null
+++ b/changelog.d/9911.doc
@@ -0,0 +1 @@
+Add `port` argument to the Postgres database sample config section.
\ No newline at end of file
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index e0350279ad..9e22696170 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -152,6 +152,14 @@ presence:
 #
 #gc_thresholds: [700, 10, 10]
 
+# The minimum time in seconds between each GC for a generation, regardless of
+# the GC thresholds. This ensures that we don't do GC too frequently.
+#
+# A value of `[1, 10, 30]` indicates that a second must pass between consecutive
+# generation 0 GCs, ten seconds between consecutive generation 1 GCs, and so on.
+#
+# gc_min_seconds_between: [1, 10, 30]
+
 # Set the limit on the returned events in the timeline in the get
 # and sync operations. The default value is 100. -1 means no upper limit.
 #
@@ -810,6 +818,7 @@ caches:
 #    password: secretpassword
 #    database: synapse
 #    host: localhost
+#    port: 5432
 #    cp_min: 5
 #    cp_max: 10
 #
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 1a15ceee81..a3fe9a3f38 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -455,6 +455,9 @@ def start(config_options):
 
     synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
+    if config.server.gc_seconds:
+        synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
+
     hs = GenericWorkerServer(
         config.server_name,
         config=config,
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 8e78134bbe..6a823da10d 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -342,6 +342,9 @@ def setup(config_options):
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
+    if config.server.gc_seconds:
+        synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
+
     hs = SynapseHomeServer(
         config.server_name,
         config=config,
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 79a02706b4..c76ef1e1de 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -58,6 +58,7 @@ DEFAULT_CONFIG = """\
 #    password: secretpassword
 #    database: synapse
 #    host: localhost
+#    port: 5432
 #    cp_min: 5
 #    cp_max: 10
 #
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 21ca7b33e3..ca1c9711f8 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -572,6 +572,7 @@ class ServerConfig(Config):
             _warn_if_webclient_configured(self.listeners)
 
         self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
+        self.gc_seconds = read_gc_thresholds(config.get("gc_min_seconds_between", None))
 
         @attr.s
         class LimitRemoteRoomsConfig:
@@ -917,6 +918,14 @@ class ServerConfig(Config):
         #
         #gc_thresholds: [700, 10, 10]
 
+        # The minimum time in seconds between each GC for a generation, regardless of
+        # the GC thresholds. This ensures that we don't do GC too frequently.
+        #
+        # A value of `[1, 10, 30]` indicates that a second must pass between consecutive
+        # generation 0 GCs, ten seconds between consecutive generation 1 GCs, and so on.
+        #
+        # gc_min_seconds_between: [1, 10, 30]
+
         # Set the limit on the returned events in the timeline in the get
         # and sync operations. The default value is 100. -1 means no upper limit.
         #
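
The `gc_min_seconds_between` comment above, together with the `homeserver.py` and `generic_worker.py` hunks, feeds the new option into `synapse.metrics.MIN_TIME_BETWEEN_GCS`. As a rough standalone sketch of the throttling behaviour being configured, mirroring the check added to `synapse/metrics/__init__.py` later in this diff (the `maybe_collect` helper is illustrative, not part of Synapse):

```python
import gc
import time

# Per-generation minimum interval between manual collections, in seconds;
# in Synapse this would come from the `gc_min_seconds_between` option.
MIN_TIME_BETWEEN_GCS = [1, 10, 30]

# Timestamp of the last manual collection we did for each generation.
_last_gc = [0.0, 0.0, 0.0]


def maybe_collect() -> None:
    """Collect any generation whose count exceeds its threshold, but never
    more often than the configured per-generation minimum interval."""
    threshold = gc.get_threshold()
    counts = gc.get_count()
    now = time.time()

    # Oldest generation first, as in the real check.
    for gen in (2, 1, 0):
        over_threshold = threshold[gen] < counts[gen]
        rate_limited = now - _last_gc[gen] < MIN_TIME_BETWEEN_GCS[gen]
        if over_threshold and not rate_limited:
            gc.collect(gen)
            _last_gc[gen] = now
```

In Synapse itself the equivalent check runs inside the reactor tick wrapper; the sketch only isolates the throttling decision.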
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 5f18ef7748..a8c8df2bad 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -16,8 +16,7 @@
 import abc
 import logging
 import urllib
-from collections import defaultdict
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
 
 import attr
 from signedjson.key import (
@@ -42,17 +41,14 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.config.key import TrustedKeyServer
-from synapse.logging.context import (
-    PreserveLoggingContext,
-    make_deferred_yieldable,
-    preserve_fn,
-    run_in_background,
-)
+from synapse.events import EventBase
+from synapse.events.utils import prune_event_dict
+from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.keys import FetchKeyResult
 from synapse.types import JsonDict
 from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import yieldable_gather_results
-from synapse.util.metrics import Measure
+from synapse.util.async_helpers import Linearizer, yieldable_gather_results
 from synapse.util.retryutils import NotRetryingDestination
 
 if TYPE_CHECKING:
@@ -74,8 +70,6 @@ class VerifyJsonRequest:
         minimum_valid_until_ts: time at which we require the signing key to
             be valid. (0 implies we don't care)
 
-        request_name: The name of the request.
-
         key_ids: The set of key_ids to that could be used to verify the JSON object
 
         key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
@@ -88,20 +82,94 @@ class VerifyJsonRequest:
     """
 
     server_name = attr.ib(type=str)
-    json_object = attr.ib(type=JsonDict)
+    json_object_callback = attr.ib(type=Callable[[], JsonDict])
     minimum_valid_until_ts = attr.ib(type=int)
-    request_name = attr.ib(type=str)
-    key_ids = attr.ib(init=False, type=List[str])
-    key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
+    key_ids = attr.ib(type=List[str])
+
+    @staticmethod
+    def from_json_object(
+        server_name: str, minimum_valid_until_ms: int, json_object: JsonDict
+    ):
+        key_ids = signature_ids(json_object, server_name)
+        return VerifyJsonRequest(
+            server_name, lambda: json_object, minimum_valid_until_ms, key_ids
+        )
 
-    def __attrs_post_init__(self):
-        self.key_ids = signature_ids(self.json_object, self.server_name)
+    @staticmethod
+    def from_event(
+        server_name: str,
+        minimum_valid_until_ms: int,
+        event: EventBase,
+    ):
+        key_ids = list(event.signatures.get(server_name, []))
+        return VerifyJsonRequest(
+            server_name,
+            lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
+            minimum_valid_until_ms,
+            key_ids,
+        )
 
 
 class KeyLookupError(ValueError):
     pass
 
 
+@attr.s(slots=True)
+class _QueueValue:
+    server_name = attr.ib(type=str)
+    minimum_valid_until_ts = attr.ib(type=int)
+    key_ids = attr.ib(type=List[str])
+
+
+class _Queue:
+    def __init__(self, name, clock, process_items):
+        self._name = name
+        self._clock = clock
+        self._is_processing = False
+        self._next_values = []
+
+        self.process_items = process_items
+
+    async def add_to_queue(self, value: _QueueValue) -> Dict[str, FetchKeyResult]:
+        d = defer.Deferred()
+        self._next_values.append((value, d))
+
+        if self._is_processing:
+            return await d
+
+        run_as_background_process(self._name, self._unsafe_process)
+
+        return await d
+
+    async def _unsafe_process(self):
+        # We purposefully defer to the next loop.
+        await self._clock.sleep(0)
+
+        try:
+            if self._is_processing:
+                return
+
+            self._is_processing = True
+
+            while self._next_values:
+                next_values = self._next_values
+                self._next_values = []
+
+                try:
+                    values = [value for value, _ in next_values]
+                    results = await self.process_items(values)
+
+                    for value, deferred in next_values:
+                        deferred.callback(results.get(value.server_name, {}))
+
+                except Exception as e:
+                    for _, deferred in next_values:
+                        deferred.errback(e)
+
+        finally:
+            self._is_processing = False
+
+
 class Keyring:
     def __init__(
         self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
@@ -116,12 +184,7 @@ class Keyring:
             )
         self._key_fetchers = key_fetchers
 
-        # map from server name to Deferred. Has an entry for each server with
-        # an ongoing key download; the Deferred completes once the download
-        # completes.
-        #
-        # These are regular, logcontext-agnostic Deferreds.
-        self.key_downloads = {}  # type: Dict[str, defer.Deferred]
+        self._server_queue = Linearizer("keyring_server")
 
     def verify_json_for_server(
         self,
@@ -130,365 +193,150 @@ class Keyring:
         validity_time: int,
         request_name: str,
     ) -> defer.Deferred:
-        """Verify that a JSON object has been signed by a given server
-
-        Args:
-            server_name: name of the server which must have signed this object
-
-            json_object: object to be checked
-
-            validity_time: timestamp at which we require the signing key to
-                be valid. (0 implies we don't care)
-
-            request_name: an identifier for this json object (eg, an event id)
-                for logging.
-
-        Returns:
-            Deferred[None]: completes if the the object was correctly signed, otherwise
-                errbacks with an error
-        """
-        req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
-        requests = (req,)
-        return make_deferred_yieldable(self._verify_objects(requests)[0])
+        request = VerifyJsonRequest.from_json_object(
+            server_name,
+            validity_time,
+            json_object,
+        )
+        return defer.ensureDeferred(self._verify_object(request))
 
     def verify_json_objects_for_server(
         self, server_and_json: Iterable[Tuple[str, dict, int, str]]
     ) -> List[defer.Deferred]:
-        """Bulk verifies signatures of json objects, bulk fetching keys as
-        necessary.
-
-        Args:
-            server_and_json:
-                Iterable of (server_name, json_object, validity_time, request_name)
-                tuples.
-
-                validity_time is a timestamp at which the signing key must be
-                valid.
-
-                request_name is an identifier for this json object (eg, an event id)
-                for logging.
-
-        Returns:
-            List<Deferred[None]>: for each input triplet, a deferred indicating success
-                or failure to verify each json object's signature for the given
-                server_name. The deferreds run their callbacks in the sentinel
-                logcontext.
-        """
-        return self._verify_objects(
-            VerifyJsonRequest(server_name, json_object, validity_time, request_name)
+        return [
+            defer.ensureDeferred(
+                self._verify_object(
+                    VerifyJsonRequest.from_json_object(
+                        server_name,
+                        validity_time,
+                        json_object,
+                    )
+                )
+            )
             for server_name, json_object, validity_time, request_name in server_and_json
-        )
+        ]
 
-    def _verify_objects(
-        self, verify_requests: Iterable[VerifyJsonRequest]
+    def verify_events_for_server(
+        self, server_and_json: Iterable[Tuple[str, EventBase, int]]
     ) -> List[defer.Deferred]:
-        """Does the work of verify_json_[objects_]for_server
-
-
-        Args:
-            verify_requests: Iterable of verification requests.
-
-        Returns:
-            List<Deferred[None]>: for each input item, a deferred indicating success
-                or failure to verify each json object's signature for the given
-                server_name. The deferreds run their callbacks in the sentinel
-                logcontext.
-        """
-        # a list of VerifyJsonRequests which are awaiting a key lookup
-        key_lookups = []
-        handle = preserve_fn(_handle_key_deferred)
-
-        def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
-            """Process an entry in the request list
-
-            Adds a key request to key_lookups, and returns a deferred which
-            will complete or fail (in the sentinel context) when verification completes.
-            """
-            if not verify_request.key_ids:
-                return defer.fail(
-                    SynapseError(
-                        400,
-                        "Not signed by %s" % (verify_request.server_name,),
-                        Codes.UNAUTHORIZED,
+        return [
+            defer.ensureDeferred(
+                self._verify_object(
+                    VerifyJsonRequest.from_event(
+                        server_name,
+                        validity_time,
+                        event,
                     )
                 )
-
-            logger.debug(
-                "Verifying %s for %s with key_ids %s, min_validity %i",
-                verify_request.request_name,
-                verify_request.server_name,
-                verify_request.key_ids,
-                verify_request.minimum_valid_until_ts,
-            )
-
-            # add the key request to the queue, but don't start it off yet.
-            key_lookups.append(verify_request)
-
-            # now run _handle_key_deferred, which will wait for the key request
-            # to complete and then do the verification.
-            #
-            # We want _handle_key_request to log to the right context, so we
-            # wrap it with preserve_fn (aka run_in_background)
-            return handle(verify_request)
-
-        results = [process(r) for r in verify_requests]
-
-        if key_lookups:
-            run_in_background(self._start_key_lookups, key_lookups)
-
-        return results
-
-    async def _start_key_lookups(
-        self, verify_requests: List[VerifyJsonRequest]
-    ) -> None:
-        """Sets off the key fetches for each verify request
-
-        Once each fetch completes, verify_request.key_ready will be resolved.
-
-        Args:
-            verify_requests:
-        """
-
-        try:
-            # map from server name to a set of outstanding request ids
-            server_to_request_ids = {}  # type: Dict[str, Set[int]]
-
-            for verify_request in verify_requests:
-                server_name = verify_request.server_name
-                request_id = id(verify_request)
-                server_to_request_ids.setdefault(server_name, set()).add(request_id)
-
-            # Wait for any previous lookups to complete before proceeding.
-            await self.wait_for_previous_lookups(server_to_request_ids.keys())
-
-            # take out a lock on each of the servers by sticking a Deferred in
-            # key_downloads
-            for server_name in server_to_request_ids.keys():
-                self.key_downloads[server_name] = defer.Deferred()
-                logger.debug("Got key lookup lock on %s", server_name)
-
-            # When we've finished fetching all the keys for a given server_name,
-            # drop the lock by resolving the deferred in key_downloads.
-            def drop_server_lock(server_name):
-                d = self.key_downloads.pop(server_name)
-                d.callback(None)
-
-            def lookup_done(res, verify_request):
-                server_name = verify_request.server_name
-                server_requests = server_to_request_ids[server_name]
-                server_requests.remove(id(verify_request))
-
-                # if there are no more requests for this server, we can drop the lock.
-                if not server_requests:
-                    logger.debug("Releasing key lookup lock on %s", server_name)
-                    drop_server_lock(server_name)
-
-                return res
-
-            for verify_request in verify_requests:
-                verify_request.key_ready.addBoth(lookup_done, verify_request)
-
-            # Actually start fetching keys.
-            self._get_server_verify_keys(verify_requests)
-        except Exception:
-            logger.exception("Error starting key lookups")
-
-    async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
-        """Waits for any previous key lookups for the given servers to finish.
-
-        Args:
-            server_names: list of servers which we want to look up
-
-        Returns:
-            Resolves once all key lookups for the given servers have
-                completed. Follows the synapse rules of logcontext preservation.
-        """
-        loop_count = 1
-        while True:
-            wait_on = [
-                (server_name, self.key_downloads[server_name])
-                for server_name in server_names
-                if server_name in self.key_downloads
-            ]
-            if not wait_on:
-                break
-            logger.info(
-                "Waiting for existing lookups for %s to complete [loop %i]",
-                [w[0] for w in wait_on],
-                loop_count,
             )
-            with PreserveLoggingContext():
-                await defer.DeferredList((w[1] for w in wait_on))
-
-            loop_count += 1
-
-    def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
-        """Tries to find at least one key for each verify request
-
-        For each verify_request, verify_request.key_ready is called back with
-        params (server_name, key_id, VerifyKey) if a key is found, or errbacked
-        with a SynapseError if none of the keys are found.
+            for server_name, event, validity_time in server_and_json
+        ]
+
+    async def _verify_object(self, verify_request: VerifyJsonRequest):
+        # TODO: Batch up verification requests for the same server rather than handling them one at a time.
+        with (await self._server_queue.queue(verify_request.server_name)):
+            found_keys: Dict[str, FetchKeyResult] = {}
+            missing_key_ids = set(verify_request.key_ids)
+            for fetcher in self._key_fetchers:
+                if not missing_key_ids:
+                    break
+
+                keys = await fetcher.get_keys(
+                    verify_request.server_name,
+                    list(missing_key_ids),
+                    verify_request.minimum_valid_until_ts,
+                )
 
-        Args:
-            verify_requests: list of verify requests
-        """
+                for key_id, key in keys.items():
+                    if not key:
+                        continue
 
-        remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
+                    if key.valid_until_ts < verify_request.minimum_valid_until_ts:
+                        continue
 
-        async def do_iterations():
-            try:
-                with Measure(self.clock, "get_server_verify_keys"):
-                    for f in self._key_fetchers:
-                        if not remaining_requests:
-                            return
-                        await self._attempt_key_fetches_with_fetcher(
-                            f, remaining_requests
-                        )
-
-                    # look for any requests which weren't satisfied
-                    while remaining_requests:
-                        verify_request = remaining_requests.pop()
-                        rq_str = (
-                            "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
-                            % (
-                                verify_request.server_name,
-                                verify_request.key_ids,
-                                verify_request.minimum_valid_until_ts,
-                            )
-                        )
-
-                        # If we run the errback immediately, it may cancel our
-                        # loggingcontext while we are still in it, so instead we
-                        # schedule it for the next time round the reactor.
-                        #
-                        # (this also ensures that we don't get a stack overflow if we
-                        # has a massive queue of lookups waiting for this server).
-                        self.clock.call_later(
-                            0,
-                            verify_request.key_ready.errback,
-                            SynapseError(
-                                401,
-                                "Failed to find any key to satisfy %s" % (rq_str,),
-                                Codes.UNAUTHORIZED,
-                            ),
-                        )
-            except Exception as err:
-                # we don't really expect to get here, because any errors should already
-                # have been caught and logged. But if we do, let's log the error and make
-                # sure that all of the deferreds are resolved.
-                logger.error("Unexpected error in _get_server_verify_keys: %s", err)
-                with PreserveLoggingContext():
-                    for verify_request in remaining_requests:
-                        if not verify_request.key_ready.called:
-                            verify_request.key_ready.errback(err)
-
-        run_in_background(do_iterations)
-
-    async def _attempt_key_fetches_with_fetcher(
-        self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
-    ):
-        """Use a key fetcher to attempt to satisfy some key requests
+                    existing_key = found_keys.get(key_id)
+                    if existing_key:
+                        if key.valid_until_ts <= existing_key.valid_until_ts:
+                            continue
 
-        Args:
-            fetcher: fetcher to use to fetch the keys
-            remaining_requests: outstanding key requests.
-                Any successfully-completed requests will be removed from the list.
-        """
-        # The keys to fetch.
-        # server_name -> key_id -> min_valid_ts
-        missing_keys = defaultdict(dict)  # type: Dict[str, Dict[str, int]]
+                    found_keys[key_id] = key
 
-        for verify_request in remaining_requests:
-            # any completed requests should already have been removed
-            assert not verify_request.key_ready.called
-            keys_for_server = missing_keys[verify_request.server_name]
+                missing_key_ids.difference_update(found_keys)
 
-            for key_id in verify_request.key_ids:
-                # If we have several requests for the same key, then we only need to
-                # request that key once, but we should do so with the greatest
-                # min_valid_until_ts of the requests, so that we can satisfy all of
-                # the requests.
-                keys_for_server[key_id] = max(
-                    keys_for_server.get(key_id, -1),
-                    verify_request.minimum_valid_until_ts,
+            if missing_key_ids:
+                raise SynapseError(
+                    400,
+                    "Missing keys for %s: %s"
+                    % (verify_request.server_name, missing_key_ids),
+                    Codes.UNAUTHORIZED,
                 )
 
-        results = await fetcher.get_keys(missing_keys)
-
-        completed = []
-        for verify_request in remaining_requests:
-            server_name = verify_request.server_name
-
-            # see if any of the keys we got this time are sufficient to
-            # complete this VerifyJsonRequest.
-            result_keys = results.get(server_name, {})
             for key_id in verify_request.key_ids:
-                fetch_key_result = result_keys.get(key_id)
-                if not fetch_key_result:
-                    # we didn't get a result for this key
-                    continue
-
-                if (
-                    fetch_key_result.valid_until_ts
-                    < verify_request.minimum_valid_until_ts
-                ):
-                    # key was not valid at this point
-                    continue
-
-                # we have a valid key for this request. If we run the callback
-                # immediately, it may cancel our loggingcontext while we are still in
-                # it, so instead we schedule it for the next time round the reactor.
-                #
-                # (this also ensures that we don't get a stack overflow if we had
-                # a massive queue of lookups waiting for this server).
-                logger.debug(
-                    "Found key %s:%s for %s",
-                    server_name,
-                    key_id,
-                    verify_request.request_name,
-                )
-                self.clock.call_later(
-                    0,
-                    verify_request.key_ready.callback,
-                    (server_name, key_id, fetch_key_result.verify_key),
-                )
-                completed.append(verify_request)
-                break
-
-        remaining_requests.difference_update(completed)
+                verify_key = found_keys[key_id].verify_key
+                try:
+                    json_object = verify_request.json_object_callback()
+                    verify_signed_json(
+                        json_object,
+                        verify_request.server_name,
+                        verify_key,
+                    )
+                except SignatureVerifyException as e:
+                    logger.debug(
+                        "Error verifying signature for %s:%s:%s with key %s: %s",
+                        verify_request.server_name,
+                        verify_key.alg,
+                        verify_key.version,
+                        encode_verify_key_base64(verify_key),
+                        str(e),
+                    )
+                    raise SynapseError(
+                        401,
+                        "Invalid signature for server %s with key %s:%s: %s"
+                        % (
+                            verify_request.server_name,
+                            verify_key.alg,
+                            verify_key.version,
+                            str(e),
+                        ),
+                        Codes.UNAUTHORIZED,
+                    )
 
 
 class KeyFetcher(metaclass=abc.ABCMeta):
-    @abc.abstractmethod
+    def __init__(self, hs: "HomeServer"):
+        self._queue = _Queue(self.__class__.__name__, hs.get_clock(), self._fetch_keys)
+
     async def get_keys(
-        self, keys_to_fetch: Dict[str, Dict[str, int]]
-    ) -> Dict[str, Dict[str, FetchKeyResult]]:
-        """
-        Args:
-            keys_to_fetch:
-                the keys to be fetched. server_name -> key_id -> min_valid_ts
+        self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+    ) -> Dict[str, FetchKeyResult]:
+        return await self._queue.add_to_queue(
+            _QueueValue(
+                server_name=server_name,
+                key_ids=key_ids,
+                minimum_valid_until_ts=minimum_valid_until_ts,
+            )
+        )
 
-        Returns:
-            Map from server_name -> key_id -> FetchKeyResult
-        """
-        raise NotImplementedError
+    @abc.abstractmethod
+    async def _fetch_keys(
+        self, keys_to_fetch: List[_QueueValue]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
+        pass
 
 
 class StoreKeyFetcher(KeyFetcher):
     """KeyFetcher impl which fetches keys from our data store"""
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        super().__init__(hs)
 
-    async def get_keys(
-        self, keys_to_fetch: Dict[str, Dict[str, int]]
-    ) -> Dict[str, Dict[str, FetchKeyResult]]:
-        """see KeyFetcher.get_keys"""
+        self.store = hs.get_datastore()
 
+    async def _fetch_keys(self, keys_to_fetch: List[_QueueValue]):
         key_ids_to_fetch = (
-            (server_name, key_id)
-            for server_name, keys_for_server in keys_to_fetch.items()
-            for key_id in keys_for_server.keys()
+            (queue_value.server_name, key_id)
+            for queue_value in keys_to_fetch
+            for key_id in queue_value.key_ids
         )
 
         res = await self.store.get_server_verify_keys(key_ids_to_fetch)
@@ -500,6 +348,8 @@ class StoreKeyFetcher(KeyFetcher):
 
 class BaseV2KeyFetcher(KeyFetcher):
     def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
         self.store = hs.get_datastore()
         self.config = hs.config
 
@@ -607,10 +457,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         self.client = hs.get_federation_http_client()
         self.key_servers = self.config.key_servers
 
-    async def get_keys(
-        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    async def _fetch_keys(
+        self, keys_to_fetch: List[_QueueValue]
     ) -> Dict[str, Dict[str, FetchKeyResult]]:
-        """see KeyFetcher.get_keys"""
+        """see KeyFetcher._fetch_keys"""
 
         async def get_key(key_server: TrustedKeyServer) -> Dict:
             try:
@@ -646,12 +496,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         return union_of_keys
 
     async def get_server_verify_key_v2_indirect(
-        self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
+        self, keys_to_fetch: List[_QueueValue], key_server: TrustedKeyServer
     ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """
         Args:
             keys_to_fetch:
-                the keys to be fetched. server_name -> key_id -> min_valid_ts
+                the keys to be fetched.
 
             key_server: notary server to query for the keys
 
@@ -665,7 +515,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         perspective_name = key_server.server_name
         logger.info(
             "Requesting keys %s from notary server %s",
-            keys_to_fetch.items(),
+            keys_to_fetch,
             perspective_name,
         )
 
@@ -675,11 +525,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
                 path="/_matrix/key/v2/query",
                 data={
                     "server_keys": {
-                        server_name: {
-                            key_id: {"minimum_valid_until_ts": min_valid_ts}
-                            for key_id, min_valid_ts in server_keys.items()
+                        queue_value.server_name: {
+                            key_id: {
+                                "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
+                            }
+                            for key_id in queue_value.key_ids
                         }
-                        for server_name, server_keys in keys_to_fetch.items()
+                        for queue_value in keys_to_fetch
                     }
                 },
             )
@@ -779,8 +631,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         self.clock = hs.get_clock()
         self.client = hs.get_federation_http_client()
 
-    async def get_keys(
-        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    async def _fetch_keys(
+        self, keys_to_fetch: List[_QueueValue]
     ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """
         Args:
@@ -793,8 +645,10 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
 
         results = {}
 
-        async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
-            server_name, key_ids = key_to_fetch_item
+        async def get_key(key_to_fetch_item: _QueueValue) -> None:
+            server_name = key_to_fetch_item.server_name
+            key_ids = key_to_fetch_item.key_ids
+
             try:
                 keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
                 results[server_name] = keys
@@ -805,7 +659,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             except Exception:
                 logger.exception("Error getting keys %s from %s", key_ids, server_name)
 
-        await yieldable_gather_results(get_key, keys_to_fetch.items())
+        await yieldable_gather_results(get_key, keys_to_fetch)
         return results
 
     async def get_server_verify_key_v2_direct(
@@ -877,37 +731,3 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             keys.update(response_keys)
 
         return keys
-
-
-async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
-    """Waits for the key to become available, and then performs a verification
-
-    Args:
-        verify_request:
-
-    Raises:
-        SynapseError if there was a problem performing the verification
-    """
-    server_name = verify_request.server_name
-    with PreserveLoggingContext():
-        _, key_id, verify_key = await verify_request.key_ready
-
-    json_object = verify_request.json_object
-
-    try:
-        verify_signed_json(json_object, server_name, verify_key)
-    except SignatureVerifyException as e:
-        logger.debug(
-            "Error verifying signature for %s:%s:%s with key %s: %s",
-            server_name,
-            verify_key.alg,
-            verify_key.version,
-            encode_verify_key_base64(verify_key),
-            str(e),
-        )
-        raise SynapseError(
-            401,
-            "Invalid signature for server %s with key %s:%s: %s"
-            % (server_name, verify_key.alg, verify_key.version, str(e)),
-            Codes.UNAUTHORIZED,
-        )
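
`_Queue` above is a small batching helper: each `KeyFetcher.get_keys` call appends a `_QueueValue` plus a Deferred, and a single background task drains everything that has queued up and resolves every Deferred from one `_fetch_keys` call. A self-contained sketch of the same coalescing pattern using plain asyncio (the `BatchingQueue` name and `fetch_many` callback are illustrative, not Synapse APIs):

```python
import asyncio
from typing import Awaitable, Callable, Dict, List, Tuple


class BatchingQueue:
    """Coalesce concurrent requests into single calls to `process_batch`."""

    def __init__(
        self, process_batch: Callable[[List[str]], Awaitable[Dict[str, str]]]
    ) -> None:
        self._process_batch = process_batch
        self._pending: List[Tuple[str, asyncio.Future]] = []
        self._is_processing = False

    async def add(self, key: str) -> str:
        fut: asyncio.Future = asyncio.get_running_loop().create_future()
        self._pending.append((key, fut))
        if not self._is_processing:
            asyncio.create_task(self._drain())
        return await fut

    async def _drain(self) -> None:
        # Yield once so that callers queued in quick succession get batched.
        await asyncio.sleep(0)
        if self._is_processing:
            return
        self._is_processing = True
        try:
            while self._pending:
                batch, self._pending = self._pending, []
                try:
                    results = await self._process_batch([k for k, _ in batch])
                    for key, fut in batch:
                        fut.set_result(results.get(key, ""))
                except Exception as e:
                    for _, fut in batch:
                        fut.set_exception(e)
        finally:
            self._is_processing = False


async def main() -> None:
    async def fetch_many(keys: List[str]) -> Dict[str, str]:
        print("fetching batch:", keys)
        return {k: k.upper() for k in keys}

    queue = BatchingQueue(fetch_many)
    print(await asyncio.gather(queue.add("a"), queue.add("b"), queue.add("c")))


asyncio.run(main())
```

All three `add()` calls above resolve from a single `fetch_many` invocation, which is the effect the keyring change relies on to avoid fetching the same server's keys repeatedly.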
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 949dcd4614..3fe496dcd3 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -137,11 +137,7 @@ class FederationBase:
         return deferreds
 
 
-class PduToCheckSig(
-    namedtuple(
-        "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
-    )
-):
+class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
     pass
 
 
@@ -184,7 +180,6 @@ def _check_sigs_on_pdus(
     pdus_to_check = [
         PduToCheckSig(
             pdu=p,
-            redacted_pdu_json=prune_event(p).get_pdu_json(),
             sender_domain=get_domain_from_id(p.sender),
             deferreds=[],
         )
@@ -195,13 +190,12 @@ def _check_sigs_on_pdus(
     # (except if its a 3pid invite, in which case it may be sent by any server)
     pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
 
-    more_deferreds = keyring.verify_json_objects_for_server(
+    more_deferreds = keyring.verify_events_for_server(
         [
             (
                 p.sender_domain,
-                p.redacted_pdu_json,
+                p.pdu,
                 p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
-                p.pdu.event_id,
             )
             for p in pdus_to_check_sender
         ]
@@ -230,13 +224,12 @@ def _check_sigs_on_pdus(
             if p.sender_domain != get_domain_from_id(p.pdu.event_id)
         ]
 
-        more_deferreds = keyring.verify_json_objects_for_server(
+        more_deferreds = keyring.verify_events_for_server(
             [
                 (
                     get_domain_from_id(p.pdu.event_id),
-                    p.redacted_pdu_json,
+                    p.pdu,
                     p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
-                    p.pdu.event_id,
                 )
                 for p in pdus_to_check_event_id
             ]
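
Note that `_check_sigs_on_pdus` no longer precomputes `redacted_pdu_json` for every PDU; `VerifyJsonRequest.from_event` (in the keyring diff above) wraps the pruning in a callback, so the redacted JSON is only built at the moment the signature is actually verified. A toy illustration of that lazy-callback idea, with a stand-in `prune` function rather than the real `prune_event_dict`:

```python
from typing import Callable, Dict, List


def prune(event: Dict) -> Dict:
    """Stand-in for prune_event_dict(): keep only the fields covered by the signature."""
    keep = {"type", "sender", "origin_server_ts", "content"}
    return {k: v for k, v in event.items() if k in keep}


class LazyVerifyRequest:
    """Hold a callback instead of the redacted JSON itself, so the pruned copy
    is only materialised when the signature check runs."""

    def __init__(self, server_name: str, get_json: Callable[[], Dict]) -> None:
        self.server_name = server_name
        self.get_json = get_json

    def check(self) -> None:
        json_object = self.get_json()  # pruning happens here, not at queue time
        print(f"checking {self.server_name}: {sorted(json_object)}")


events: List[Dict] = [
    {
        "type": "m.room.member",
        "sender": "@a:example.org",
        "origin_server_ts": 1,
        "content": {},
        "unsigned": {"age": 5},
    },
]
requests = [LazyVerifyRequest("example.org", (lambda e=e: prune(e))) for e in events]
for req in requests:
    req.check()
```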
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index a5b6a61195..20812e706e 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -33,6 +33,7 @@ from typing import (
 )
 
 import attr
+import ijson
 from prometheus_client import Counter
 
 from twisted.internet import defer
@@ -55,11 +56,16 @@ from synapse.api.room_versions import (
 )
 from synapse.events import EventBase, builder
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
-from synapse.logging.context import make_deferred_yieldable, preserve_fn
+from synapse.logging.context import (
+    get_thread_resource_usage,
+    make_deferred_yieldable,
+    preserve_fn,
+)
 from synapse.logging.utils import log_function
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.iterutils import batch_iter
 from synapse.util.retryutils import NotRetryingDestination
 
 if TYPE_CHECKING:
@@ -667,19 +673,37 @@ class FederationClient(FederationBase):
         async def send_request(destination) -> Dict[str, Any]:
             content = await self._do_send_join(destination, pdu)
 
-            logger.debug("Got content: %s", content)
+            # logger.debug("Got content: %s", content.getvalue())
 
-            state = [
-                event_from_pdu_json(p, room_version, outlier=True)
-                for p in content.get("state", [])
-            ]
+            # logger.info("send_join content: %d", len(content))
 
-            auth_chain = [
-                event_from_pdu_json(p, room_version, outlier=True)
-                for p in content.get("auth_chain", [])
-            ]
+            content.seek(0)
+
+            r = get_thread_resource_usage()
+            logger.info("Memory before state: %s", r.ru_maxrss)
+
+            state = []
+            for i, p in enumerate(ijson.items(content, "state.item")):
+                state.append(event_from_pdu_json(p, room_version, outlier=True))
+                if i % 1000 == 999:
+                    await self._clock.sleep(0)
+
+            r = get_thread_resource_usage()
+            logger.info("Memory after state: %s", r.ru_maxrss)
+
+            logger.info("Parsed state: %d", len(state))
+            content.seek(0)
+
+            auth_chain = []
+            for i, p in enumerate(ijson.items(content, "auth_chain.item")):
+                auth_chain.append(event_from_pdu_json(p, room_version, outlier=True))
+                if i % 1000 == 999:
+                    await self._clock.sleep(0)
+
+            r = get_thread_resource_usage()
+            logger.info("Memory after: %s", r.ru_maxrss)
 
-            pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
+            logger.info("Parsed auth chain: %d", len(auth_chain))
 
             create_event = None
             for e in state:
@@ -704,12 +728,19 @@ class FederationClient(FederationBase):
                     % (create_room_version,)
                 )
 
-            valid_pdus = await self._check_sigs_and_hash_and_fetch(
-                destination,
-                list(pdus.values()),
-                outlier=True,
-                room_version=room_version,
-            )
+            valid_pdus = []
+
+            for chunk in batch_iter(itertools.chain(state, auth_chain), 1000):
+                logger.info("Handling next _check_sigs_and_hash_and_fetch chunk")
+                new_valid_pdus = await self._check_sigs_and_hash_and_fetch(
+                    destination,
+                    chunk,
+                    outlier=True,
+                    room_version=room_version,
+                )
+                valid_pdus.extend(new_valid_pdus)
+
+            logger.info("_check_sigs_and_hash_and_fetch done")
 
             valid_pdus_map = {p.event_id: p for p in valid_pdus}
 
@@ -744,6 +775,8 @@ class FederationClient(FederationBase):
                     % (auth_chain_create_events,)
                 )
 
+            logger.info("Returning from send_join")
+
             return {
                 "state": signed_state,
                 "auth_chain": signed_auth,
@@ -769,6 +802,8 @@ class FederationClient(FederationBase):
             if not self._is_unknown_endpoint(e):
                 raise
 
+        raise NotImplementedError()
+
         logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
 
         resp = await self.transport_layer.send_join_v1(
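
The `send_join` change above stops decoding the whole response in one pass: the transport layer now hands back the raw buffer (`return_string_io=True`), and the events are streamed out of it with `ijson`, one pass for `state` and another for `auth_chain`, yielding to the reactor every 1000 items. A minimal sketch of that incremental-parsing approach over an in-memory buffer (the sample payload is made up):

```python
import io
import json

import ijson

# A stand-in for the buffered /send_join response body.
payload = json.dumps(
    {
        "state": [{"type": "m.room.create", "event_id": "$create"}],
        "auth_chain": [{"type": "m.room.member", "event_id": "$join"}],
    }
).encode("utf-8")

buf = io.BytesIO(payload)

# ijson.items() streams objects matching the prefix without materialising the
# rest of the document; "state.item" yields each element of the top-level
# "state" array.
state = [event for event in ijson.items(buf, "state.item")]

# Rewind and make a second pass for the auth chain, as the diff does.
buf.seek(0)
auth_chain = [event for event in ijson.items(buf, "auth_chain.item")]

print(len(state), len(auth_chain))
```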
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index ada322a81e..9c0a105f36 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -244,7 +244,10 @@ class TransportLayerClient:
         path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
 
         response = await self.client.put_json(
-            destination=destination, path=path, data=content
+            destination=destination,
+            path=path,
+            data=content,
+            return_string_io=True,
         )
 
         return response
@@ -254,7 +257,10 @@ class TransportLayerClient:
         path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
 
         response = await self.client.put_json(
-            destination=destination, path=path, data=content
+            destination=destination,
+            path=path,
+            data=content,
+            return_string_io=True,
         )
 
         return response
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index de1b14cde3..4064a2b859 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -78,7 +78,7 @@ class DirectoryHandler(BaseHandler):
         # TODO(erikj): Add transactions.
         # TODO(erikj): Check if there is a current association.
         if not servers:
-            users = await self.state.get_current_users_in_room(room_id)
+            users = await self.store.get_users_in_room(room_id)
             servers = {get_domain_from_id(u) for u in users}
 
         if not servers:
@@ -270,7 +270,7 @@ class DirectoryHandler(BaseHandler):
                 Codes.NOT_FOUND,
             )
 
-        users = await self.state.get_current_users_in_room(room_id)
+        users = await self.store.get_users_in_room(room_id)
         extra_servers = {get_domain_from_id(u) for u in users}
         servers = set(extra_servers) | set(servers)
 
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index d82144d7fa..f134f1e234 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -103,7 +103,7 @@ class EventStreamHandler(BaseHandler):
                     # Send down presence.
                     if event.state_key == auth_user_id:
                         # Send down presence for everyone in the room.
-                        users = await self.state.get_current_users_in_room(
+                        users = await self.store.get_users_in_room(
                             event.room_id
                         )  # type: Iterable[str]
                     else:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 9d867aaf4d..69055a14b3 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1452,7 +1452,7 @@ class FederationHandler(BaseHandler):
         # room stuff after join currently doesn't work on workers.
         assert self.config.worker.worker_app is None
 
-        logger.debug("Joining %s to %s", joinee, room_id)
+        logger.info("Joining %s to %s", joinee, room_id)
 
         origin, event, room_version_obj = await self._make_and_verify_event(
             target_hosts,
@@ -1463,6 +1463,8 @@ class FederationHandler(BaseHandler):
             params={"ver": KNOWN_ROOM_VERSIONS},
         )
 
+        logger.info("make_join done from %s", origin)
+
         # This shouldn't happen, because the RoomMemberHandler has a
         # linearizer lock which only allows one operation per user per room
         # at a time - so this is just paranoia.
@@ -1482,10 +1484,13 @@ class FederationHandler(BaseHandler):
             except ValueError:
                 pass
 
+            logger.info("Sending join")
             ret = await self.federation_client.send_join(
                 host_list, event, room_version_obj
             )
 
+            logger.info("send join done")
+
             origin = ret["origin"]
             state = ret["state"]
             auth_chain = ret["auth_chain"]
@@ -1510,10 +1515,14 @@ class FederationHandler(BaseHandler):
                 room_version=room_version_obj,
             )
 
+            logger.info("Persisting auth true")
+
             max_stream_id = await self._persist_auth_tree(
                 origin, room_id, auth_chain, state, event, room_version_obj
             )
 
+            logger.info("Persisted auth true")
+
             # We wait here until this instance has seen the events come down
             # replication (if we're using replication) as the below uses caches.
             await self._replication.wait_for_stream_position(
@@ -2166,6 +2175,8 @@ class FederationHandler(BaseHandler):
             ctx = await self.state_handler.compute_event_context(e)
             events_to_context[e.event_id] = ctx
 
+        logger.info("Computed contexts")
+
         event_map = {
             e.event_id: e for e in itertools.chain(auth_events, state, [event])
         }
@@ -2207,6 +2218,8 @@ class FederationHandler(BaseHandler):
             else:
                 logger.info("Failed to find auth event %r", e_id)
 
+        logger.info("Got missing events")
+
         for e in itertools.chain(auth_events, state, [event]):
             auth_for_e = {
                 (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
@@ -2231,6 +2244,8 @@ class FederationHandler(BaseHandler):
                     raise
                 events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
 
+        logger.info("Authed events")
+
         await self.persist_events_and_notify(
             room_id,
             [
@@ -2239,10 +2254,14 @@ class FederationHandler(BaseHandler):
             ],
         )
 
+        logger.info("Persisted events")
+
         new_event_context = await self.state_handler.compute_event_context(
             event, old_state=state
         )
 
+        logger.info("Computed context")
+
         return await self.persist_events_and_notify(
             room_id, [(event, new_event_context)]
         )
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 49f8aa25ea..393f17c3a3 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -258,7 +258,7 @@ class MessageHandler:
                     "Getting joined members after leaving is not implemented"
                 )
 
-        users_with_profile = await self.state.get_current_users_in_room(room_id)
+        users_with_profile = await self.store.get_users_in_room_with_profiles(room_id)
 
         # If this is an AS, double check that they are allowed to see the members.
         # This can either be because the AS user is in the room or because there
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 5a888b7941..fb4823a5cc 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1327,7 +1327,7 @@ class RoomShutdownHandler:
             new_room_id = None
             logger.info("Shutting down room %r", room_id)
 
-        users = await self.state.get_current_users_in_room(room_id)
+        users = await self.store.get_users_in_room(room_id)
         kicked_users = []
         failed_to_kick_users = []
         for user_id in users:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index a9a3ee05c3..0fcc1532da 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1190,7 +1190,7 @@ class SyncHandler:
 
             # Step 1b, check for newly joined rooms
             for room_id in newly_joined_rooms:
-                joined_users = await self.state.get_current_users_in_room(room_id)
+                joined_users = await self.store.get_users_in_room(room_id)
                 newly_joined_or_invited_users.update(joined_users)
 
             # TODO: Check that these users are actually new, i.e. either they
@@ -1206,7 +1206,7 @@ class SyncHandler:
 
             # Now find users that we no longer track
             for room_id in newly_left_rooms:
-                left_users = await self.state.get_current_users_in_room(room_id)
+                left_users = await self.store.get_users_in_room(room_id)
                 newly_left_users.update(left_users)
 
             # Remove any users that we still share a room with.
@@ -1361,7 +1361,7 @@ class SyncHandler:
 
         extra_users_ids = set(newly_joined_or_invited_users)
         for room_id in newly_joined_rooms:
-            users = await self.state.get_current_users_in_room(room_id)
+            users = await self.store.get_users_in_room(room_id)
             extra_users_ids.update(users)
         extra_users_ids.discard(user.to_string())
 
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index bb837b7b19..6db1aece35 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -154,6 +154,7 @@ async def _handle_json_response(
     request: MatrixFederationRequest,
     response: IResponse,
     start_ms: int,
+    return_string_io=False,
 ) -> JsonDict:
     """
     Reads the JSON body of a response, with a timeout
@@ -175,12 +176,12 @@ async def _handle_json_response(
         d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE)
         d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
 
-        def parse(_len: int):
-            return json_decoder.decode(buf.getvalue())
+        await make_deferred_yieldable(d)
 
-        d.addCallback(parse)
-
-        body = await make_deferred_yieldable(d)
+        if return_string_io:
+            body = buf
+        else:
+            body = json_decoder.decode(buf.getvalue())
     except BodyExceededMaxSize as e:
         # The response was too big.
         logger.warning(
@@ -225,12 +226,13 @@ async def _handle_json_response(
     time_taken_secs = reactor.seconds() - start_ms / 1000
 
     logger.info(
-        "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s",
+        "{%s} [%s] Completed request: %d %s in %.2f secs got %dB - %s %s",
         request.txn_id,
         request.destination,
         response.code,
         response.phrase.decode("ascii", errors="replace"),
         time_taken_secs,
+        len(buf.getvalue()),
         request.method,
         request.uri.decode("ascii"),
     )
@@ -683,6 +685,7 @@ class MatrixFederationHttpClient:
         ignore_backoff: bool = False,
         backoff_on_404: bool = False,
         try_trailing_slash_on_400: bool = False,
+        return_string_io=False,
     ) -> Union[JsonDict, list]:
         """Sends the specified json data using PUT
 
@@ -757,7 +760,12 @@ class MatrixFederationHttpClient:
             _sec_timeout = self.default_timeout
 
         body = await _handle_json_response(
-            self.reactor, _sec_timeout, request, response, start_ms
+            self.reactor,
+            _sec_timeout,
+            request,
+            response,
+            start_ms,
+            return_string_io=return_string_io,
         )
 
         return body
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 31b7b3c256..c841363b1e 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -12,12 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import ctypes
+import ctypes.util
 import functools
 import gc
 import itertools
 import logging
 import os
 import platform
+import re
 import threading
 import time
 from typing import Callable, Dict, Iterable, Optional, Tuple, Union
@@ -535,6 +538,13 @@ class ReactorLastSeenMetric:
 
 REGISTRY.register(ReactorLastSeenMetric())
 
+# The minimum time in seconds between GCs for each generation, regardless of the current GC
+# thresholds and counts.
+MIN_TIME_BETWEEN_GCS = [1, 10, 30]
+
+# The time (in seconds since the epoch) at which we last did a GC for each generation.
+_last_gc = [0, 0, 0]
+
 
 def runUntilCurrentTimer(reactor, func):
     @functools.wraps(func)
@@ -575,11 +585,16 @@ def runUntilCurrentTimer(reactor, func):
             return ret
 
         # Check if we need to do a manual GC (since its been disabled), and do
-        # one if necessary.
+        # one if necessary. Note we go in reverse order as e.g. a gen 1 GC may
+        # promote an object into gen 2, and we don't want to handle the same
+        # object multiple times.
         threshold = gc.get_threshold()
         counts = gc.get_count()
         for i in (2, 1, 0):
-            if threshold[i] < counts[i]:
+            # We check if we need to do one based on a straightforward
+            # comparison between the threshold and count. We also do an extra
+            # check to make sure that we don't do a GC too often.
+            if threshold[i] < counts[i] and MIN_TIME_BETWEEN_GCS[i] < end - _last_gc[i]:
                 if i == 0:
                     logger.debug("Collecting gc %d", i)
                 else:
@@ -589,6 +604,8 @@ def runUntilCurrentTimer(reactor, func):
                 unreachable = gc.collect(i)
                 end = time.time()
 
+                _last_gc[i] = int(end)
+
                 gc_time.labels(i).observe(end - start)
                 gc_unreachable.labels(i).set(unreachable)
 
@@ -597,6 +614,163 @@ def runUntilCurrentTimer(reactor, func):
     return f
 
 
+def _setup_jemalloc_stats():
+    """Checks to see if jemalloc is loaded, and hooks up a collector to record
+    statistics exposed by jemalloc.
+    """
+
+    # Try to find the loaded jemalloc shared library, if any. We need to
+    # introspect into what is loaded, rather than loading whatever is on the
+    # path, as if we load a *different* jemalloc version things will seg fault.
+    pid = os.getpid()
+
+    # We're looking for a path at the end of the line that includes
+    # "libjemalloc".
+    regex = re.compile(r"/\S+/libjemalloc.*$")
+
+    jemalloc_path = None
+    with open(f"/proc/{pid}/maps") as f:
+        for line in f.readlines():
+            match = regex.search(line.strip())
+            if match:
+                jemalloc_path = match.group()
+
+    if not jemalloc_path:
+        # No loaded jemalloc was found.
+        return
+
+    jemalloc = ctypes.CDLL(jemalloc_path)
+
+    def _mallctl(
+        name: str, read: bool = True, write: Optional[int] = None
+    ) -> Optional[int]:
+        """Wrapper around `mallctl` for reading and writing integers to
+        jemalloc.
+
+        Args:
+            name: The name of the option to read from/write to.
+            read: Whether to try and read the value.
+            write: The value to write, if given.
+
+        Returns:
+            The value read if `read` is True, otherwise None.
+
+        Raises:
+            An exception if `mallctl` returns a non-zero error code.
+        """
+
+        input_var = None
+        input_var_ref = None
+        input_len_ref = None
+        if read:
+            input_var = ctypes.c_size_t(0)
+            input_len = ctypes.c_size_t(ctypes.sizeof(input_var))
+
+            input_var_ref = ctypes.byref(input_var)
+            input_len_ref = ctypes.byref(input_len)
+
+        write_var_ref = None
+        write_len = ctypes.c_size_t(0)
+        if write is not None:
+            write_var = ctypes.c_size_t(write)
+            write_len = ctypes.c_size_t(ctypes.sizeof(write_var))
+
+            write_var_ref = ctypes.byref(write_var)
+
+        # The interface is:
+        #
+        #   int mallctl(
+        #       const char *name,
+        #       void *oldp,
+        #       size_t *oldlenp,
+        #       void *newp,
+        #       size_t newlen
+        #   )
+        #
+        # Where oldp/oldlenp is a buffer to which the old value will be written
+        # (if not null), and newp/newlen is the buffer with the new value to set
+        # (if not null). Note that they're all references *except* newlen.
+        result = jemalloc.mallctl(
+            name.encode("ascii"),
+            input_var_ref,
+            input_len_ref,
+            write_var_ref,
+            write_len,
+        )
+
+        if result != 0:
+            raise Exception("Failed to call mallctl")
+
+        if input_var is None:
+            return None
+
+        return input_var.value
+
+    def _jemalloc_refresh_stats() -> None:
+        """Request that jemalloc updates its internal statistics. This needs to
+        be called before querying for stats, otherwise it will return stale
+        values.
+        """
+        try:
+            _mallctl("epoch", read=False, write=1)
+        except Exception:
+            pass
+
+    class JemallocCollector:
+        """Metrics for internal jemalloc stats."""
+
+        def collect(self):
+            _jemalloc_refresh_stats()
+
+            g = GaugeMetricFamily(
+                "jemalloc_stats_app_memory",
+                "The stats reported by jemalloc",
+                labels=["type"],
+            )
+
+            # Read the relevant global stats from jemalloc. Note that these may
+            # not be accurate if Python is configured to use its internal small
+            # object allocator (which is on by default; disable it by setting
+            # the env var `PYTHONMALLOC=malloc`).
+            #
+            # See the jemalloc manpage for details about what each value means,
+            # roughly:
+            #   - allocated ─ Total number of bytes allocated by the app
+            #   - active ─ Total number of bytes in active pages allocated by
+            #     the application; this is bigger than `allocated`.
+            #   - resident ─ Maximum number of bytes in physically resident data
+            #     pages mapped by the allocator, comprising all pages dedicated
+            #     to allocator metadata, pages backing active allocations, and
+            #     unused dirty pages. This is bigger than `active`.
+            #   - mapped ─ Total number of bytes in active extents mapped by the
+            #     allocator.
+            #   - metadata ─ Total number of bytes dedicated to jemalloc
+            #     metadata.
+            for t in (
+                "allocated",
+                "active",
+                "resident",
+                "mapped",
+                "metadata",
+            ):
+                try:
+                    value = _mallctl(f"stats.{t}")
+                except Exception:
+                    # There was an error fetching the value; skip it.
+                    continue
+
+                g.add_metric([t], value=value)
+
+            yield g
+
+    REGISTRY.register(JemallocCollector())
+
+
+try:
+    _setup_jemalloc_stats()
+except Exception:
+    logger.info("Failed to set up collector to record jemalloc stats.")
+
 try:
     # Ensure the reactor has all the attributes we expect
     reactor.seconds  # type: ignore
@@ -615,6 +789,7 @@ try:
 except AttributeError:
     pass
 
+
 __all__ = [
     "MetricsResource",
     "generate_latest",
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
index 1a3b051e3c..b402f82810 100644
--- a/synapse/replication/tcp/external_cache.py
+++ b/synapse/replication/tcp/external_cache.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Any, Optional
 
-from prometheus_client import Counter
+from prometheus_client import Counter, Histogram
 
 from synapse.logging.context import make_deferred_yieldable
 from synapse.util import json_decoder, json_encoder
@@ -35,6 +35,20 @@ get_counter = Counter(
     labelnames=["cache_name", "hit"],
 )
 
+response_timer = Histogram(
+    "synapse_external_cache_response_time_seconds",
+    "Time taken to get a response from Redis for a cache get/set request",
+    labelnames=["method"],
+    buckets=(
+        0.001,
+        0.002,
+        0.005,
+        0.01,
+        0.02,
+        0.05,
+    ),
+)
+
 
 logger = logging.getLogger(__name__)
 
@@ -72,13 +86,14 @@ class ExternalCache:
 
         logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
 
-        return await make_deferred_yieldable(
-            self._redis_connection.set(
-                self._get_redis_key(cache_name, key),
-                encoded_value,
-                pexpire=expiry_ms,
+        with response_timer.labels("set").time():
+            return await make_deferred_yieldable(
+                self._redis_connection.set(
+                    self._get_redis_key(cache_name, key),
+                    encoded_value,
+                    pexpire=expiry_ms,
+                )
             )
-        )
 
     async def get(self, cache_name: str, key: str) -> Optional[Any]:
         """Look up a key/value in the named cache."""
@@ -86,9 +101,10 @@ class ExternalCache:
         if self._redis_connection is None:
             return None
 
-        result = await make_deferred_yieldable(
-            self._redis_connection.get(self._get_redis_key(cache_name, key))
-        )
+        with response_timer.labels("get").time():
+            result = await make_deferred_yieldable(
+                self._redis_connection.get(self._get_redis_key(cache_name, key))
+            )
 
         logger.debug("Got cache result %s %s: %r", cache_name, key, result)
 
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f648678b09..e19e9ef5c7 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -22,6 +22,7 @@ from synapse.crypto.keyring import ServerKeyFetcher
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_integer, parse_json_object_from_request
 from synapse.util import json_decoder
+from synapse.util.async_helpers import yieldable_gather_results
 
 logger = logging.getLogger(__name__)
 
@@ -213,7 +214,13 @@ class RemoteKey(DirectServeJsonResource):
         # If there is a cache miss, request the missing keys, then recurse (and
         # ensure the result is sent).
         if cache_misses and query_remote_on_cache_miss:
-            await self.fetcher.get_keys(cache_misses)
+            await yieldable_gather_results(
+                lambda t: self.fetcher.get_keys(*t),
+                (
+                    (server_name, list(keys), 0)
+                    for server_name, keys in cache_misses.items()
+                ),
+            )
             await self.query_keys(request, query, query_remote_on_cache_miss=False)
         else:
             signed_keys = []
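The effect of the change above is to fan out one `get_keys` call per server concurrently instead of a single batched call; a hedged sketch of the equivalent using plain asyncio (Synapse's `yieldable_gather_results` does the same while preserving log contexts):

    import asyncio
    from typing import Dict, Iterable

    async def fetch_missing_keys(fetcher, cache_misses: Dict[str, Iterable[str]]) -> None:
        """Sketch: request each server's missing keys concurrently.

        `fetcher.get_keys(server_name, key_ids, minimum_valid_until_ts)` is the
        per-server calling convention implied by the patch; `cache_misses` maps
        a server name to the key IDs missing from the local cache."""
        await asyncio.gather(
            *(
                fetcher.get_keys(server_name, list(keys), 0)
                for server_name, keys in cache_misses.items()
            )
        )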
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index b3bd92d37c..a1770f620e 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -213,19 +213,23 @@ class StateHandler:
         return ret.state
 
     async def get_current_users_in_room(
-        self, room_id: str, latest_event_ids: Optional[List[str]] = None
+        self, room_id: str, latest_event_ids: List[str]
     ) -> Dict[str, ProfileInfo]:
         """
         Get the users who are currently in a room.
 
+        Note: This is much slower than the equivalent methods
+        `DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`,
+        so it should only be used when you need the users at a particular point
+        in the room's history.
+
         Args:
             room_id: The ID of the room.
             latest_event_ids: Precomputed list of latest event IDs.
         Returns:
             Dictionary of user IDs to their profileinfo.
         """
-        if not latest_event_ids:
-            latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+
         assert latest_event_ids is not None
 
         logger.debug("calling resolve_state_groups from get_current_users_in_room")
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 6b68d8720c..3d98d3f5f8 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -69,6 +69,7 @@ class SQLBaseStore(metaclass=ABCMeta):
             self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
 
         self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
+        self._attempt_to_invalidate_cache("get_users_in_room_with_profiles", (room_id,))
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
         self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
 
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 2a8532f8c1..5fc3bb5a7d 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -205,8 +205,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]:
             sql = """
-                SELECT user_id, display_name, avatar_url FROM room_memberships
-                WHERE room_id = ? AND membership = ?
+                SELECT state_key, display_name, avatar_url FROM room_memberships as m
+                INNER JOIN current_state_events as c
+                ON m.event_id = c.event_id
+                AND m.room_id = c.room_id
+                AND m.user_id = c.state_key
+                WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
             """
             txn.execute(sql, (room_id, Membership.JOIN))
 
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 7a082fdd21..a6bfb4902a 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -142,8 +142,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             batch_size (int): Maximum number of state events to process
                 per cycle.
         """
-        state = self.hs.get_state_handler()
-
         # If we don't have any progress recorded, delete everything.
         if not progress:
             await self.delete_all_from_user_dir()
@@ -197,7 +195,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                     room_id
                 )
 
-                users_with_profile = await state.get_current_users_in_room(room_id)
+                users_with_profile = await self.get_users_in_room_with_profiles(room_id)
                 user_ids = set(users_with_profile)
 
                 # Update each user in the user directory.
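For completeness, a hedged sketch of consuming the storage-level method the background update now uses (the wrapper function is hypothetical; `ProfileInfo` exposes `display_name` and `avatar_url`):

    async def count_joined_users_with_avatars(store, room_id: str) -> int:
        """Sketch: the returned dict maps each joined user's MXID (the
        state_key selected in the SQL above) to a ProfileInfo."""
        users_with_profile = await store.get_users_in_room_with_profiles(room_id)
        return sum(1 for p in users_with_profile.values() if p.avatar_url)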