diff options
29 files changed, 583 insertions, 467 deletions
diff --git a/changelog.d/9882.misc b/changelog.d/9882.misc new file mode 100644 index 0000000000..facfa31f38 --- /dev/null +++ b/changelog.d/9882.misc @@ -0,0 +1 @@ +Export jemalloc stats to Prometheus if it is being used. diff --git a/changelog.d/9902.feature b/changelog.d/9902.feature new file mode 100644 index 0000000000..4d9f324d4e --- /dev/null +++ b/changelog.d/9902.feature @@ -0,0 +1 @@ +Add limits to how often Synapse will GC, ensuring that large servers do not end up GC thrashing if `gc_thresholds` has not been correctly set. diff --git a/changelog.d/9904.misc b/changelog.d/9904.misc new file mode 100644 index 0000000000..3db1e625ae --- /dev/null +++ b/changelog.d/9904.misc @@ -0,0 +1 @@ +Time response time for external cache requests. diff --git a/changelog.d/9910.bugfix b/changelog.d/9910.bugfix new file mode 100644 index 0000000000..06d523fd46 --- /dev/null +++ b/changelog.d/9910.bugfix @@ -0,0 +1 @@ +Fix bug where user directory could get out of sync if room visibility and membership changed in quick succession. diff --git a/changelog.d/9910.misc b/changelog.d/9910.feature index 54165cce18..54165cce18 100644 --- a/changelog.d/9910.misc +++ b/changelog.d/9910.feature diff --git a/changelog.d/9911.doc b/changelog.d/9911.doc new file mode 100644 index 0000000000..f7fd9f1ba9 --- /dev/null +++ b/changelog.d/9911.doc @@ -0,0 +1 @@ +Add `port` argument to the Postgres database sample config section. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index e0350279ad..9e22696170 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -152,6 +152,14 @@ presence: # #gc_thresholds: [700, 10, 10] +# The minimum time in seconds between each GC for a generation, regardless of +# the GC thresholds. This ensures that we don't do GC too frequently. +# +# A value of `[1, 10, 30]` indicates that a second must pass between consecutive +# generation 0 GCs, etc. +# +# gc_min_seconds_between: [1, 10, 30] + # Set the limit on the returned events in the timeline in the get # and sync operations. The default value is 100. -1 means no upper limit. # @@ -810,6 +818,7 @@ caches: # password: secretpassword # database: synapse # host: localhost +# port: 5432 # cp_min: 5 # cp_max: 10 # diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 1a15ceee81..a3fe9a3f38 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -455,6 +455,9 @@ def start(config_options): synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts + if config.server.gc_seconds: + synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds + hs = GenericWorkerServer( config.server_name, config=config, diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 8e78134bbe..6a823da10d 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -342,6 +342,9 @@ def setup(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts + if config.server.gc_seconds: + synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds + hs = SynapseHomeServer( config.server_name, config=config, diff --git a/synapse/config/database.py b/synapse/config/database.py index 79a02706b4..c76ef1e1de 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -58,6 +58,7 @@ DEFAULT_CONFIG = """\ # password: secretpassword # database: synapse # host: localhost +# port: 5432 # cp_min: 5 # cp_max: 10 # diff --git a/synapse/config/server.py b/synapse/config/server.py index 21ca7b33e3..ca1c9711f8 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -572,6 +572,7 @@ class ServerConfig(Config): _warn_if_webclient_configured(self.listeners) self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None)) + self.gc_seconds = read_gc_thresholds(config.get("gc_min_seconds_between", None)) @attr.s class LimitRemoteRoomsConfig: @@ -917,6 +918,14 @@ class ServerConfig(Config): # #gc_thresholds: [700, 10, 10] + # The minimum time in seconds between each GC for a generation, regardless of + # the GC thresholds. This ensures that we don't do GC too frequently. + # + # A value of `[1, 10, 30]` indicates that a second must pass between consecutive + # generation 0 GCs, etc. + # + # gc_min_seconds_between: [1, 10, 30] + # Set the limit on the returned events in the timeline in the get # and sync operations. The default value is 100. -1 means no upper limit. # diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 5f18ef7748..a8c8df2bad 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -16,8 +16,7 @@ import abc import logging import urllib -from collections import defaultdict -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple import attr from signedjson.key import ( @@ -42,17 +41,14 @@ from synapse.api.errors import ( SynapseError, ) from synapse.config.key import TrustedKeyServer -from synapse.logging.context import ( - PreserveLoggingContext, - make_deferred_yieldable, - preserve_fn, - run_in_background, -) +from synapse.events import EventBase +from synapse.events.utils import prune_event_dict +from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict from synapse.util import unwrapFirstError -from synapse.util.async_helpers import yieldable_gather_results -from synapse.util.metrics import Measure +from synapse.util.async_helpers import Linearizer, yieldable_gather_results from synapse.util.retryutils import NotRetryingDestination if TYPE_CHECKING: @@ -74,8 +70,6 @@ class VerifyJsonRequest: minimum_valid_until_ts: time at which we require the signing key to be valid. (0 implies we don't care) - request_name: The name of the request. - key_ids: The set of key_ids to that could be used to verify the JSON object key_ready (Deferred[str, str, nacl.signing.VerifyKey]): @@ -88,20 +82,94 @@ class VerifyJsonRequest: """ server_name = attr.ib(type=str) - json_object = attr.ib(type=JsonDict) + json_object_callback = attr.ib(type=Callable[[], JsonDict]) minimum_valid_until_ts = attr.ib(type=int) - request_name = attr.ib(type=str) - key_ids = attr.ib(init=False, type=List[str]) - key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred) + key_ids = attr.ib(type=List[str]) + + @staticmethod + def from_json_object( + server_name: str, minimum_valid_until_ms: int, json_object: JsonDict + ): + key_ids = signature_ids(json_object, server_name) + return VerifyJsonRequest( + server_name, lambda: json_object, minimum_valid_until_ms, key_ids + ) - def __attrs_post_init__(self): - self.key_ids = signature_ids(self.json_object, self.server_name) + @staticmethod + def from_event( + server_name: str, + minimum_valid_until_ms: int, + event: EventBase, + ): + key_ids = list(event.signatures.get(server_name, [])) + return VerifyJsonRequest( + server_name, + lambda: prune_event_dict(event.room_version, event.get_pdu_json()), + minimum_valid_until_ms, + key_ids, + ) class KeyLookupError(ValueError): pass +@attr.s(slots=True) +class _QueueValue: + server_name = attr.ib(type=str) + minimum_valid_until_ts = attr.ib(type=int) + key_ids = attr.ib(type=List[str]) + + +class _Queue: + def __init__(self, name, clock, process_items): + self._name = name + self._clock = clock + self._is_processing = False + self._next_values = [] + + self.process_items = process_items + + async def add_to_queue(self, value: _QueueValue) -> Dict[str, FetchKeyResult]: + d = defer.Deferred() + self._next_values.append((value, d)) + + if self._is_processing: + return await d + + run_as_background_process(self._name, self._unsafe_process) + + return await d + + async def _unsafe_process(self): + # We purposefully defer to the next loop. + await self._clock.sleep(0) + + try: + if self._is_processing: + return + + self._is_processing = True + + while self._next_values: + next_values = self._next_values + self._next_values = [] + + try: + values = [value for value, _ in next_values] + results = await self.process_items(values) + + for value, deferred in next_values: + deferred.callback(results.get(value.server_name, {})) + + except Exception as e: + for _, deferred in next_values: + deferred.errback(e) + + finally: + self._is_processing = False + + class Keyring: def __init__( self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None @@ -116,12 +184,7 @@ class Keyring: ) self._key_fetchers = key_fetchers - # map from server name to Deferred. Has an entry for each server with - # an ongoing key download; the Deferred completes once the download - # completes. - # - # These are regular, logcontext-agnostic Deferreds. - self.key_downloads = {} # type: Dict[str, defer.Deferred] + self._server_queue = Linearizer("keyring_server") def verify_json_for_server( self, @@ -130,365 +193,150 @@ class Keyring: validity_time: int, request_name: str, ) -> defer.Deferred: - """Verify that a JSON object has been signed by a given server - - Args: - server_name: name of the server which must have signed this object - - json_object: object to be checked - - validity_time: timestamp at which we require the signing key to - be valid. (0 implies we don't care) - - request_name: an identifier for this json object (eg, an event id) - for logging. - - Returns: - Deferred[None]: completes if the the object was correctly signed, otherwise - errbacks with an error - """ - req = VerifyJsonRequest(server_name, json_object, validity_time, request_name) - requests = (req,) - return make_deferred_yieldable(self._verify_objects(requests)[0]) + request = VerifyJsonRequest.from_json_object( + server_name, + validity_time, + json_object, + ) + return defer.ensureDeferred(self._verify_object(request)) def verify_json_objects_for_server( self, server_and_json: Iterable[Tuple[str, dict, int, str]] ) -> List[defer.Deferred]: - """Bulk verifies signatures of json objects, bulk fetching keys as - necessary. - - Args: - server_and_json: - Iterable of (server_name, json_object, validity_time, request_name) - tuples. - - validity_time is a timestamp at which the signing key must be - valid. - - request_name is an identifier for this json object (eg, an event id) - for logging. - - Returns: - List<Deferred[None]>: for each input triplet, a deferred indicating success - or failure to verify each json object's signature for the given - server_name. The deferreds run their callbacks in the sentinel - logcontext. - """ - return self._verify_objects( - VerifyJsonRequest(server_name, json_object, validity_time, request_name) + return [ + defer.ensureDeferred( + self._verify_object( + VerifyJsonRequest.from_json_object( + server_name, + validity_time, + json_object, + ) + ) + ) for server_name, json_object, validity_time, request_name in server_and_json - ) + ] - def _verify_objects( - self, verify_requests: Iterable[VerifyJsonRequest] + def verify_events_for_server( + self, server_and_json: Iterable[Tuple[str, EventBase, int]] ) -> List[defer.Deferred]: - """Does the work of verify_json_[objects_]for_server - - - Args: - verify_requests: Iterable of verification requests. - - Returns: - List<Deferred[None]>: for each input item, a deferred indicating success - or failure to verify each json object's signature for the given - server_name. The deferreds run their callbacks in the sentinel - logcontext. - """ - # a list of VerifyJsonRequests which are awaiting a key lookup - key_lookups = [] - handle = preserve_fn(_handle_key_deferred) - - def process(verify_request: VerifyJsonRequest) -> defer.Deferred: - """Process an entry in the request list - - Adds a key request to key_lookups, and returns a deferred which - will complete or fail (in the sentinel context) when verification completes. - """ - if not verify_request.key_ids: - return defer.fail( - SynapseError( - 400, - "Not signed by %s" % (verify_request.server_name,), - Codes.UNAUTHORIZED, + return [ + defer.ensureDeferred( + self._verify_object( + VerifyJsonRequest.from_event( + server_name, + validity_time, + event, ) ) - - logger.debug( - "Verifying %s for %s with key_ids %s, min_validity %i", - verify_request.request_name, - verify_request.server_name, - verify_request.key_ids, - verify_request.minimum_valid_until_ts, - ) - - # add the key request to the queue, but don't start it off yet. - key_lookups.append(verify_request) - - # now run _handle_key_deferred, which will wait for the key request - # to complete and then do the verification. - # - # We want _handle_key_request to log to the right context, so we - # wrap it with preserve_fn (aka run_in_background) - return handle(verify_request) - - results = [process(r) for r in verify_requests] - - if key_lookups: - run_in_background(self._start_key_lookups, key_lookups) - - return results - - async def _start_key_lookups( - self, verify_requests: List[VerifyJsonRequest] - ) -> None: - """Sets off the key fetches for each verify request - - Once each fetch completes, verify_request.key_ready will be resolved. - - Args: - verify_requests: - """ - - try: - # map from server name to a set of outstanding request ids - server_to_request_ids = {} # type: Dict[str, Set[int]] - - for verify_request in verify_requests: - server_name = verify_request.server_name - request_id = id(verify_request) - server_to_request_ids.setdefault(server_name, set()).add(request_id) - - # Wait for any previous lookups to complete before proceeding. - await self.wait_for_previous_lookups(server_to_request_ids.keys()) - - # take out a lock on each of the servers by sticking a Deferred in - # key_downloads - for server_name in server_to_request_ids.keys(): - self.key_downloads[server_name] = defer.Deferred() - logger.debug("Got key lookup lock on %s", server_name) - - # When we've finished fetching all the keys for a given server_name, - # drop the lock by resolving the deferred in key_downloads. - def drop_server_lock(server_name): - d = self.key_downloads.pop(server_name) - d.callback(None) - - def lookup_done(res, verify_request): - server_name = verify_request.server_name - server_requests = server_to_request_ids[server_name] - server_requests.remove(id(verify_request)) - - # if there are no more requests for this server, we can drop the lock. - if not server_requests: - logger.debug("Releasing key lookup lock on %s", server_name) - drop_server_lock(server_name) - - return res - - for verify_request in verify_requests: - verify_request.key_ready.addBoth(lookup_done, verify_request) - - # Actually start fetching keys. - self._get_server_verify_keys(verify_requests) - except Exception: - logger.exception("Error starting key lookups") - - async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None: - """Waits for any previous key lookups for the given servers to finish. - - Args: - server_names: list of servers which we want to look up - - Returns: - Resolves once all key lookups for the given servers have - completed. Follows the synapse rules of logcontext preservation. - """ - loop_count = 1 - while True: - wait_on = [ - (server_name, self.key_downloads[server_name]) - for server_name in server_names - if server_name in self.key_downloads - ] - if not wait_on: - break - logger.info( - "Waiting for existing lookups for %s to complete [loop %i]", - [w[0] for w in wait_on], - loop_count, ) - with PreserveLoggingContext(): - await defer.DeferredList((w[1] for w in wait_on)) - - loop_count += 1 - - def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None: - """Tries to find at least one key for each verify request - - For each verify_request, verify_request.key_ready is called back with - params (server_name, key_id, VerifyKey) if a key is found, or errbacked - with a SynapseError if none of the keys are found. + for server_name, event, validity_time in server_and_json + ] + + async def _verify_object(self, verify_request: VerifyJsonRequest): + # TODO: Use a batching thing. + with (await self._server_queue.queue(verify_request.server_name)): + found_keys: Dict[str, FetchKeyResult] = {} + missing_key_ids = set(verify_request.key_ids) + for fetcher in self._key_fetchers: + if not missing_key_ids: + break + + keys = await fetcher.get_keys( + verify_request.server_name, + list(missing_key_ids), + verify_request.minimum_valid_until_ts, + ) - Args: - verify_requests: list of verify requests - """ + for key_id, key in keys.items(): + if not key: + continue - remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} + if key.valid_until_ts < verify_request.minimum_valid_until_ts: + continue - async def do_iterations(): - try: - with Measure(self.clock, "get_server_verify_keys"): - for f in self._key_fetchers: - if not remaining_requests: - return - await self._attempt_key_fetches_with_fetcher( - f, remaining_requests - ) - - # look for any requests which weren't satisfied - while remaining_requests: - verify_request = remaining_requests.pop() - rq_str = ( - "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)" - % ( - verify_request.server_name, - verify_request.key_ids, - verify_request.minimum_valid_until_ts, - ) - ) - - # If we run the errback immediately, it may cancel our - # loggingcontext while we are still in it, so instead we - # schedule it for the next time round the reactor. - # - # (this also ensures that we don't get a stack overflow if we - # has a massive queue of lookups waiting for this server). - self.clock.call_later( - 0, - verify_request.key_ready.errback, - SynapseError( - 401, - "Failed to find any key to satisfy %s" % (rq_str,), - Codes.UNAUTHORIZED, - ), - ) - except Exception as err: - # we don't really expect to get here, because any errors should already - # have been caught and logged. But if we do, let's log the error and make - # sure that all of the deferreds are resolved. - logger.error("Unexpected error in _get_server_verify_keys: %s", err) - with PreserveLoggingContext(): - for verify_request in remaining_requests: - if not verify_request.key_ready.called: - verify_request.key_ready.errback(err) - - run_in_background(do_iterations) - - async def _attempt_key_fetches_with_fetcher( - self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest] - ): - """Use a key fetcher to attempt to satisfy some key requests + existing_key = found_keys.get(key_id) + if existing_key: + if key.valid_until_ts <= existing_key.valid_until_ts: + continue - Args: - fetcher: fetcher to use to fetch the keys - remaining_requests: outstanding key requests. - Any successfully-completed requests will be removed from the list. - """ - # The keys to fetch. - # server_name -> key_id -> min_valid_ts - missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]] + found_keys[key_id] = key - for verify_request in remaining_requests: - # any completed requests should already have been removed - assert not verify_request.key_ready.called - keys_for_server = missing_keys[verify_request.server_name] + missing_key_ids.difference_update(found_keys) - for key_id in verify_request.key_ids: - # If we have several requests for the same key, then we only need to - # request that key once, but we should do so with the greatest - # min_valid_until_ts of the requests, so that we can satisfy all of - # the requests. - keys_for_server[key_id] = max( - keys_for_server.get(key_id, -1), - verify_request.minimum_valid_until_ts, + if missing_key_ids: + raise SynapseError( + 400, + "Missing keys for %s: %s" + % (verify_request.server_name, missing_key_ids), + Codes.UNAUTHORIZED, ) - results = await fetcher.get_keys(missing_keys) - - completed = [] - for verify_request in remaining_requests: - server_name = verify_request.server_name - - # see if any of the keys we got this time are sufficient to - # complete this VerifyJsonRequest. - result_keys = results.get(server_name, {}) for key_id in verify_request.key_ids: - fetch_key_result = result_keys.get(key_id) - if not fetch_key_result: - # we didn't get a result for this key - continue - - if ( - fetch_key_result.valid_until_ts - < verify_request.minimum_valid_until_ts - ): - # key was not valid at this point - continue - - # we have a valid key for this request. If we run the callback - # immediately, it may cancel our loggingcontext while we are still in - # it, so instead we schedule it for the next time round the reactor. - # - # (this also ensures that we don't get a stack overflow if we had - # a massive queue of lookups waiting for this server). - logger.debug( - "Found key %s:%s for %s", - server_name, - key_id, - verify_request.request_name, - ) - self.clock.call_later( - 0, - verify_request.key_ready.callback, - (server_name, key_id, fetch_key_result.verify_key), - ) - completed.append(verify_request) - break - - remaining_requests.difference_update(completed) + verify_key = found_keys[key_id].verify_key + try: + json_object = verify_request.json_object_callback() + verify_signed_json( + json_object, + verify_request.server_name, + verify_key, + ) + except SignatureVerifyException as e: + logger.debug( + "Error verifying signature for %s:%s:%s with key %s: %s", + verify_request.server_name, + verify_key.alg, + verify_key.version, + encode_verify_key_base64(verify_key), + str(e), + ) + raise SynapseError( + 401, + "Invalid signature for server %s with key %s:%s: %s" + % ( + verify_request.server_name, + verify_key.alg, + verify_key.version, + str(e), + ), + Codes.UNAUTHORIZED, + ) class KeyFetcher(metaclass=abc.ABCMeta): - @abc.abstractmethod + def __init__(self, hs: "HomeServer"): + self._queue = _Queue(self.__class__.__name__, hs.get_clock(), self._fetch_keys) + async def get_keys( - self, keys_to_fetch: Dict[str, Dict[str, int]] - ) -> Dict[str, Dict[str, FetchKeyResult]]: - """ - Args: - keys_to_fetch: - the keys to be fetched. server_name -> key_id -> min_valid_ts + self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int + ) -> Dict[str, FetchKeyResult]: + return await self._queue.add_to_queue( + _QueueValue( + server_name=server_name, + key_ids=key_ids, + minimum_valid_until_ts=minimum_valid_until_ts, + ) + ) - Returns: - Map from server_name -> key_id -> FetchKeyResult - """ - raise NotImplementedError + @abc.abstractmethod + async def _fetch_keys( + self, keys_to_fetch: List[_QueueValue] + ) -> Dict[str, Dict[str, FetchKeyResult]]: + pass class StoreKeyFetcher(KeyFetcher): """KeyFetcher impl which fetches keys from our data store""" def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + super().__init__(hs) - async def get_keys( - self, keys_to_fetch: Dict[str, Dict[str, int]] - ) -> Dict[str, Dict[str, FetchKeyResult]]: - """see KeyFetcher.get_keys""" + self.store = hs.get_datastore() + async def _fetch_keys(self, keys_to_fetch: List[_QueueValue]): key_ids_to_fetch = ( - (server_name, key_id) - for server_name, keys_for_server in keys_to_fetch.items() - for key_id in keys_for_server.keys() + (queue_value.server_name, key_id) + for queue_value in keys_to_fetch + for key_id in queue_value.key_ids ) res = await self.store.get_server_verify_keys(key_ids_to_fetch) @@ -500,6 +348,8 @@ class StoreKeyFetcher(KeyFetcher): class BaseV2KeyFetcher(KeyFetcher): def __init__(self, hs: "HomeServer"): + super().__init__(hs) + self.store = hs.get_datastore() self.config = hs.config @@ -607,10 +457,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): self.client = hs.get_federation_http_client() self.key_servers = self.config.key_servers - async def get_keys( - self, keys_to_fetch: Dict[str, Dict[str, int]] + async def _fetch_keys( + self, keys_to_fetch: List[_QueueValue] ) -> Dict[str, Dict[str, FetchKeyResult]]: - """see KeyFetcher.get_keys""" + """see KeyFetcher._fetch_keys""" async def get_key(key_server: TrustedKeyServer) -> Dict: try: @@ -646,12 +496,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): return union_of_keys async def get_server_verify_key_v2_indirect( - self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer + self, keys_to_fetch: List[_QueueValue], key_server: TrustedKeyServer ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: keys_to_fetch: - the keys to be fetched. server_name -> key_id -> min_valid_ts + the keys to be fetched. key_server: notary server to query for the keys @@ -665,7 +515,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): perspective_name = key_server.server_name logger.info( "Requesting keys %s from notary server %s", - keys_to_fetch.items(), + keys_to_fetch, perspective_name, ) @@ -675,11 +525,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): path="/_matrix/key/v2/query", data={ "server_keys": { - server_name: { - key_id: {"minimum_valid_until_ts": min_valid_ts} - for key_id, min_valid_ts in server_keys.items() + queue_value.server_name: { + key_id: { + "minimum_valid_until_ts": queue_value.minimum_valid_until_ts, + } + for key_id in queue_value.key_ids } - for server_name, server_keys in keys_to_fetch.items() + for queue_value in keys_to_fetch } }, ) @@ -779,8 +631,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher): self.clock = hs.get_clock() self.client = hs.get_federation_http_client() - async def get_keys( - self, keys_to_fetch: Dict[str, Dict[str, int]] + async def _fetch_keys( + self, keys_to_fetch: List[_QueueValue] ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: @@ -793,8 +645,10 @@ class ServerKeyFetcher(BaseV2KeyFetcher): results = {} - async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None: - server_name, key_ids = key_to_fetch_item + async def get_key(key_to_fetch_item: _QueueValue) -> None: + server_name = key_to_fetch_item.server_name + key_ids = key_to_fetch_item.key_ids + try: keys = await self.get_server_verify_key_v2_direct(server_name, key_ids) results[server_name] = keys @@ -805,7 +659,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): except Exception: logger.exception("Error getting keys %s from %s", key_ids, server_name) - await yieldable_gather_results(get_key, keys_to_fetch.items()) + await yieldable_gather_results(get_key, keys_to_fetch) return results async def get_server_verify_key_v2_direct( @@ -877,37 +731,3 @@ class ServerKeyFetcher(BaseV2KeyFetcher): keys.update(response_keys) return keys - - -async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None: - """Waits for the key to become available, and then performs a verification - - Args: - verify_request: - - Raises: - SynapseError if there was a problem performing the verification - """ - server_name = verify_request.server_name - with PreserveLoggingContext(): - _, key_id, verify_key = await verify_request.key_ready - - json_object = verify_request.json_object - - try: - verify_signed_json(json_object, server_name, verify_key) - except SignatureVerifyException as e: - logger.debug( - "Error verifying signature for %s:%s:%s with key %s: %s", - server_name, - verify_key.alg, - verify_key.version, - encode_verify_key_base64(verify_key), - str(e), - ) - raise SynapseError( - 401, - "Invalid signature for server %s with key %s:%s: %s" - % (server_name, verify_key.alg, verify_key.version, str(e)), - Codes.UNAUTHORIZED, - ) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 949dcd4614..3fe496dcd3 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -137,11 +137,7 @@ class FederationBase: return deferreds -class PduToCheckSig( - namedtuple( - "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"] - ) -): +class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])): pass @@ -184,7 +180,6 @@ def _check_sigs_on_pdus( pdus_to_check = [ PduToCheckSig( pdu=p, - redacted_pdu_json=prune_event(p).get_pdu_json(), sender_domain=get_domain_from_id(p.sender), deferreds=[], ) @@ -195,13 +190,12 @@ def _check_sigs_on_pdus( # (except if its a 3pid invite, in which case it may be sent by any server) pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)] - more_deferreds = keyring.verify_json_objects_for_server( + more_deferreds = keyring.verify_events_for_server( [ ( p.sender_domain, - p.redacted_pdu_json, + p.pdu, p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, - p.pdu.event_id, ) for p in pdus_to_check_sender ] @@ -230,13 +224,12 @@ def _check_sigs_on_pdus( if p.sender_domain != get_domain_from_id(p.pdu.event_id) ] - more_deferreds = keyring.verify_json_objects_for_server( + more_deferreds = keyring.verify_events_for_server( [ ( get_domain_from_id(p.pdu.event_id), - p.redacted_pdu_json, + p.pdu, p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, - p.pdu.event_id, ) for p in pdus_to_check_event_id ] diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index a5b6a61195..20812e706e 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -33,6 +33,7 @@ from typing import ( ) import attr +import ijson from prometheus_client import Counter from twisted.internet import defer @@ -55,11 +56,16 @@ from synapse.api.room_versions import ( ) from synapse.events import EventBase, builder from synapse.federation.federation_base import FederationBase, event_from_pdu_json -from synapse.logging.context import make_deferred_yieldable, preserve_fn +from synapse.logging.context import ( + get_thread_resource_usage, + make_deferred_yieldable, + preserve_fn, +) from synapse.logging.utils import log_function from synapse.types import JsonDict, get_domain_from_id from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.iterutils import batch_iter from synapse.util.retryutils import NotRetryingDestination if TYPE_CHECKING: @@ -667,19 +673,37 @@ class FederationClient(FederationBase): async def send_request(destination) -> Dict[str, Any]: content = await self._do_send_join(destination, pdu) - logger.debug("Got content: %s", content) + # logger.debug("Got content: %s", content.getvalue()) - state = [ - event_from_pdu_json(p, room_version, outlier=True) - for p in content.get("state", []) - ] + # logger.info("send_join content: %d", len(content)) - auth_chain = [ - event_from_pdu_json(p, room_version, outlier=True) - for p in content.get("auth_chain", []) - ] + content.seek(0) + + r = get_thread_resource_usage() + logger.info("Memory before state: %s", r.ru_maxrss) + + state = [] + for i, p in enumerate(ijson.items(content, "state.item")): + state.append(event_from_pdu_json(p, room_version, outlier=True)) + if i % 1000 == 999: + await self._clock.sleep(0) + + r = get_thread_resource_usage() + logger.info("Memory after state: %s", r.ru_maxrss) + + logger.info("Parsed state: %d", len(state)) + content.seek(0) + + auth_chain = [] + for i, p in enumerate(ijson.items(content, "auth_chain.item")): + auth_chain.append(event_from_pdu_json(p, room_version, outlier=True)) + if i % 1000 == 999: + await self._clock.sleep(0) + + r = get_thread_resource_usage() + logger.info("Memory after: %s", r.ru_maxrss) - pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)} + logger.info("Parsed auth chain: %d", len(auth_chain)) create_event = None for e in state: @@ -704,12 +728,19 @@ class FederationClient(FederationBase): % (create_room_version,) ) - valid_pdus = await self._check_sigs_and_hash_and_fetch( - destination, - list(pdus.values()), - outlier=True, - room_version=room_version, - ) + valid_pdus = [] + + for chunk in batch_iter(itertools.chain(state, auth_chain), 1000): + logger.info("Handling next _check_sigs_and_hash_and_fetch chunk") + new_valid_pdus = await self._check_sigs_and_hash_and_fetch( + destination, + chunk, + outlier=True, + room_version=room_version, + ) + valid_pdus.extend(new_valid_pdus) + + logger.info("_check_sigs_and_hash_and_fetch done") valid_pdus_map = {p.event_id: p for p in valid_pdus} @@ -744,6 +775,8 @@ class FederationClient(FederationBase): % (auth_chain_create_events,) ) + logger.info("Returning from send_join") + return { "state": signed_state, "auth_chain": signed_auth, @@ -769,6 +802,8 @@ class FederationClient(FederationBase): if not self._is_unknown_endpoint(e): raise + raise NotImplementedError() + logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API") resp = await self.transport_layer.send_join_v1( diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index ada322a81e..9c0a105f36 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -244,7 +244,10 @@ class TransportLayerClient: path = _create_v1_path("/send_join/%s/%s", room_id, event_id) response = await self.client.put_json( - destination=destination, path=path, data=content + destination=destination, + path=path, + data=content, + return_string_io=True, ) return response @@ -254,7 +257,10 @@ class TransportLayerClient: path = _create_v2_path("/send_join/%s/%s", room_id, event_id) response = await self.client.put_json( - destination=destination, path=path, data=content + destination=destination, + path=path, + data=content, + return_string_io=True, ) return response diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index de1b14cde3..4064a2b859 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -78,7 +78,7 @@ class DirectoryHandler(BaseHandler): # TODO(erikj): Add transactions. # TODO(erikj): Check if there is a current association. if not servers: - users = await self.state.get_current_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) servers = {get_domain_from_id(u) for u in users} if not servers: @@ -270,7 +270,7 @@ class DirectoryHandler(BaseHandler): Codes.NOT_FOUND, ) - users = await self.state.get_current_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) extra_servers = {get_domain_from_id(u) for u in users} servers = set(extra_servers) | set(servers) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index d82144d7fa..f134f1e234 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -103,7 +103,7 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = await self.state.get_current_users_in_room( + users = await self.store.get_users_in_room( event.room_id ) # type: Iterable[str] else: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 9d867aaf4d..69055a14b3 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1452,7 +1452,7 @@ class FederationHandler(BaseHandler): # room stuff after join currently doesn't work on workers. assert self.config.worker.worker_app is None - logger.debug("Joining %s to %s", joinee, room_id) + logger.info("Joining %s to %s", joinee, room_id) origin, event, room_version_obj = await self._make_and_verify_event( target_hosts, @@ -1463,6 +1463,8 @@ class FederationHandler(BaseHandler): params={"ver": KNOWN_ROOM_VERSIONS}, ) + logger.info("make_join done from %s", origin) + # This shouldn't happen, because the RoomMemberHandler has a # linearizer lock which only allows one operation per user per room # at a time - so this is just paranoia. @@ -1482,10 +1484,13 @@ class FederationHandler(BaseHandler): except ValueError: pass + logger.info("Sending join") ret = await self.federation_client.send_join( host_list, event, room_version_obj ) + logger.info("send join done") + origin = ret["origin"] state = ret["state"] auth_chain = ret["auth_chain"] @@ -1510,10 +1515,14 @@ class FederationHandler(BaseHandler): room_version=room_version_obj, ) + logger.info("Persisting auth true") + max_stream_id = await self._persist_auth_tree( origin, room_id, auth_chain, state, event, room_version_obj ) + logger.info("Persisted auth true") + # We wait here until this instance has seen the events come down # replication (if we're using replication) as the below uses caches. await self._replication.wait_for_stream_position( @@ -2166,6 +2175,8 @@ class FederationHandler(BaseHandler): ctx = await self.state_handler.compute_event_context(e) events_to_context[e.event_id] = ctx + logger.info("Computed contexts") + event_map = { e.event_id: e for e in itertools.chain(auth_events, state, [event]) } @@ -2207,6 +2218,8 @@ class FederationHandler(BaseHandler): else: logger.info("Failed to find auth event %r", e_id) + logger.info("Got missing events") + for e in itertools.chain(auth_events, state, [event]): auth_for_e = { (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id] @@ -2231,6 +2244,8 @@ class FederationHandler(BaseHandler): raise events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR + logger.info("Authed events") + await self.persist_events_and_notify( room_id, [ @@ -2239,10 +2254,14 @@ class FederationHandler(BaseHandler): ], ) + logger.info("Persisted events") + new_event_context = await self.state_handler.compute_event_context( event, old_state=state ) + logger.info("Computed context") + return await self.persist_events_and_notify( room_id, [(event, new_event_context)] ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 49f8aa25ea..393f17c3a3 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -258,7 +258,7 @@ class MessageHandler: "Getting joined members after leaving is not implemented" ) - users_with_profile = await self.state.get_current_users_in_room(room_id) + users_with_profile = await self.store.get_users_in_room_with_profiles(room_id) # If this is an AS, double check that they are allowed to see the members. # This can either be because the AS user is in the room or because there diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 5a888b7941..fb4823a5cc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1327,7 +1327,7 @@ class RoomShutdownHandler: new_room_id = None logger.info("Shutting down room %r", room_id) - users = await self.state.get_current_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) kicked_users = [] failed_to_kick_users = [] for user_id in users: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index a9a3ee05c3..0fcc1532da 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1190,7 +1190,7 @@ class SyncHandler: # Step 1b, check for newly joined rooms for room_id in newly_joined_rooms: - joined_users = await self.state.get_current_users_in_room(room_id) + joined_users = await self.store.get_users_in_room(room_id) newly_joined_or_invited_users.update(joined_users) # TODO: Check that these users are actually new, i.e. either they @@ -1206,7 +1206,7 @@ class SyncHandler: # Now find users that we no longer track for room_id in newly_left_rooms: - left_users = await self.state.get_current_users_in_room(room_id) + left_users = await self.store.get_users_in_room(room_id) newly_left_users.update(left_users) # Remove any users that we still share a room with. @@ -1361,7 +1361,7 @@ class SyncHandler: extra_users_ids = set(newly_joined_or_invited_users) for room_id in newly_joined_rooms: - users = await self.state.get_current_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) extra_users_ids.update(users) extra_users_ids.discard(user.to_string()) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index bb837b7b19..6db1aece35 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -154,6 +154,7 @@ async def _handle_json_response( request: MatrixFederationRequest, response: IResponse, start_ms: int, + return_string_io=False, ) -> JsonDict: """ Reads the JSON body of a response, with a timeout @@ -175,12 +176,12 @@ async def _handle_json_response( d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE) d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) - def parse(_len: int): - return json_decoder.decode(buf.getvalue()) + await make_deferred_yieldable(d) - d.addCallback(parse) - - body = await make_deferred_yieldable(d) + if return_string_io: + body = buf + else: + body = json_decoder.decode(buf.getvalue()) except BodyExceededMaxSize as e: # The response was too big. logger.warning( @@ -225,12 +226,13 @@ async def _handle_json_response( time_taken_secs = reactor.seconds() - start_ms / 1000 logger.info( - "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s", + "{%s} [%s] Completed request: %d %s in %.2f secs got %dB - %s %s", request.txn_id, request.destination, response.code, response.phrase.decode("ascii", errors="replace"), time_taken_secs, + len(buf.getvalue()), request.method, request.uri.decode("ascii"), ) @@ -683,6 +685,7 @@ class MatrixFederationHttpClient: ignore_backoff: bool = False, backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, + return_string_io=False, ) -> Union[JsonDict, list]: """Sends the specified json data using PUT @@ -757,7 +760,12 @@ class MatrixFederationHttpClient: _sec_timeout = self.default_timeout body = await _handle_json_response( - self.reactor, _sec_timeout, request, response, start_ms + self.reactor, + _sec_timeout, + request, + response, + start_ms, + return_string_io=return_string_io, ) return body diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 31b7b3c256..c841363b1e 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ctypes +import ctypes.util import functools import gc import itertools import logging import os import platform +import re import threading import time from typing import Callable, Dict, Iterable, Optional, Tuple, Union @@ -535,6 +538,13 @@ class ReactorLastSeenMetric: REGISTRY.register(ReactorLastSeenMetric()) +# The minimum time in seconds between GCs for each generation, regardless of the current GC +# thresholds and counts. +MIN_TIME_BETWEEN_GCS = [1, 10, 30] + +# The time in seconds of the last time we did a GC for each generation. +_last_gc = [0, 0, 0] + def runUntilCurrentTimer(reactor, func): @functools.wraps(func) @@ -575,11 +585,16 @@ def runUntilCurrentTimer(reactor, func): return ret # Check if we need to do a manual GC (since its been disabled), and do - # one if necessary. + # one if necessary. Note we go in reverse order as e.g. a gen 1 GC may + # promote an object into gen 2, and we don't want to handle the same + # object multiple times. threshold = gc.get_threshold() counts = gc.get_count() for i in (2, 1, 0): - if threshold[i] < counts[i]: + # We check if we need to do one based on a straightforward + # comparison between the threshold and count. We also do an extra + # check to make sure that we don't a GC too often. + if threshold[i] < counts[i] and MIN_TIME_BETWEEN_GCS[i] < end - _last_gc[i]: if i == 0: logger.debug("Collecting gc %d", i) else: @@ -589,6 +604,8 @@ def runUntilCurrentTimer(reactor, func): unreachable = gc.collect(i) end = time.time() + _last_gc[i] = int(end) + gc_time.labels(i).observe(end - start) gc_unreachable.labels(i).set(unreachable) @@ -597,6 +614,163 @@ def runUntilCurrentTimer(reactor, func): return f +def _setup_jemalloc_stats(): + """Checks to see if jemalloc is loaded, and hooks up a collector to record + statistics exposed by jemalloc. + """ + + # Try to find the loaded jemalloc shared library, if any. We need to + # introspect into what is loaded, rather than loading whatever is on the + # path, as if we load a *different* jemalloc version things will seg fault. + pid = os.getpid() + + # We're looking for a path at the end of the line that includes + # "libjemalloc". + regex = re.compile(r"/\S+/libjemalloc.*$") + + jemalloc_path = None + with open(f"/proc/{pid}/maps") as f: + for line in f.readlines(): + match = regex.search(line.strip()) + if match: + jemalloc_path = match.group() + + if not jemalloc_path: + # No loaded jemalloc was found. + return + + jemalloc = ctypes.CDLL(jemalloc_path) + + def _mallctl( + name: str, read: bool = True, write: Optional[int] = None + ) -> Optional[int]: + """Wrapper around `mallctl` for reading and writing integers to + jemalloc. + + Args: + name: The name of the option to read from/write to. + read: Whether to try and read the value. + write: The value to write, if given. + + Returns: + The value read if `read` is True, otherwise None. + + Raises: + An exception if `mallctl` returns a non-zero error code. + """ + + input_var = None + input_var_ref = None + input_len_ref = None + if read: + input_var = ctypes.c_size_t(0) + input_len = ctypes.c_size_t(ctypes.sizeof(input_var)) + + input_var_ref = ctypes.byref(input_var) + input_len_ref = ctypes.byref(input_len) + + write_var_ref = None + write_len = ctypes.c_size_t(0) + if write is not None: + write_var = ctypes.c_size_t(write) + write_len = ctypes.c_size_t(ctypes.sizeof(write_var)) + + write_var_ref = ctypes.byref(write_var) + + # The interface is: + # + # int mallctl( + # const char *name, + # void *oldp, + # size_t *oldlenp, + # void *newp, + # size_t newlen + # ) + # + # Where oldp/oldlenp is a buffer where the old value will be written to + # (if not null), and newp/newlen is the buffer with the new value to set + # (if not null). Note that they're all references *except* newlen. + result = jemalloc.mallctl( + name.encode("ascii"), + input_var_ref, + input_len_ref, + write_var_ref, + write_len, + ) + + if result != 0: + raise Exception("Failed to call mallctl") + + if input_var is None: + return None + + return input_var.value + + def _jemalloc_refresh_stats() -> None: + """Request that jemalloc updates its internal statistics. This needs to + be called before querying for stats, otherwise it will return stale + values. + """ + try: + _mallctl("epoch", read=False, write=1) + except Exception: + pass + + class JemallocCollector: + """Metrics for internal jemalloc stats.""" + + def collect(self): + _jemalloc_refresh_stats() + + g = GaugeMetricFamily( + "jemalloc_stats_app_memory", + "The stats reported by jemalloc", + labels=["type"], + ) + + # Read the relevant global stats from jemalloc. Note that these may + # not be accurate if python is configured to use its internal small + # object allocator (which is on by default, disable by setting the + # env `PYTHONMALLOC=malloc`). + # + # See the jemalloc manpage for details about what each value means, + # roughly: + # - allocated ─ Total number of bytes allocated by the app + # - active ─ Total number of bytes in active pages allocated by + # the application, this is bigger than `allocated`. + # - resident ─ Maximum number of bytes in physically resident data + # pages mapped by the allocator, comprising all pages dedicated + # to allocator metadata, pages backing active allocations, and + # unused dirty pages. This is bigger than `active`. + # - mapped ─ Total number of bytes in active extents mapped by the + # allocator. + # - metadata ─ Total number of bytes dedicated to jemalloc + # metadata. + for t in ( + "allocated", + "active", + "resident", + "mapped", + "metadata", + ): + try: + value = _mallctl(f"stats.{t}") + except Exception: + # There was an error fetching the value, skip. + continue + + g.add_metric([t], value=value) + + yield g + + REGISTRY.register(JemallocCollector()) + + +try: + _setup_jemalloc_stats() +except Exception: + logger.info("Failed to setup collector to record jemalloc stats.") + try: # Ensure the reactor has all the attributes we expect reactor.seconds # type: ignore @@ -615,6 +789,7 @@ try: except AttributeError: pass + __all__ = [ "MetricsResource", "generate_latest", diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py index 1a3b051e3c..b402f82810 100644 --- a/synapse/replication/tcp/external_cache.py +++ b/synapse/replication/tcp/external_cache.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Any, Optional -from prometheus_client import Counter +from prometheus_client import Counter, Histogram from synapse.logging.context import make_deferred_yieldable from synapse.util import json_decoder, json_encoder @@ -35,6 +35,20 @@ get_counter = Counter( labelnames=["cache_name", "hit"], ) +response_timer = Histogram( + "synapse_external_cache_response_time_seconds", + "Time taken to get a response from Redis for a cache get/set request", + labelnames=["method"], + buckets=( + 0.001, + 0.002, + 0.005, + 0.01, + 0.02, + 0.05, + ), +) + logger = logging.getLogger(__name__) @@ -72,13 +86,14 @@ class ExternalCache: logger.debug("Caching %s %s: %r", cache_name, key, encoded_value) - return await make_deferred_yieldable( - self._redis_connection.set( - self._get_redis_key(cache_name, key), - encoded_value, - pexpire=expiry_ms, + with response_timer.labels("set").time(): + return await make_deferred_yieldable( + self._redis_connection.set( + self._get_redis_key(cache_name, key), + encoded_value, + pexpire=expiry_ms, + ) ) - ) async def get(self, cache_name: str, key: str) -> Optional[Any]: """Look up a key/value in the named cache.""" @@ -86,9 +101,10 @@ class ExternalCache: if self._redis_connection is None: return None - result = await make_deferred_yieldable( - self._redis_connection.get(self._get_redis_key(cache_name, key)) - ) + with response_timer.labels("get").time(): + result = await make_deferred_yieldable( + self._redis_connection.get(self._get_redis_key(cache_name, key)) + ) logger.debug("Got cache result %s %s: %r", cache_name, key, result) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index f648678b09..e19e9ef5c7 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -22,6 +22,7 @@ from synapse.crypto.keyring import ServerKeyFetcher from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.util import json_decoder +from synapse.util.async_helpers import yieldable_gather_results logger = logging.getLogger(__name__) @@ -213,7 +214,13 @@ class RemoteKey(DirectServeJsonResource): # If there is a cache miss, request the missing keys, then recurse (and # ensure the result is sent). if cache_misses and query_remote_on_cache_miss: - await self.fetcher.get_keys(cache_misses) + await yieldable_gather_results( + lambda t: self.fetcher.get_keys(*t), + ( + (server_name, list(keys), 0) + for server_name, keys in cache_misses.items() + ), + ) await self.query_keys(request, query, query_remote_on_cache_miss=False) else: signed_keys = [] diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index b3bd92d37c..a1770f620e 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -213,19 +213,23 @@ class StateHandler: return ret.state async def get_current_users_in_room( - self, room_id: str, latest_event_ids: Optional[List[str]] = None + self, room_id: str, latest_event_ids: List[str] ) -> Dict[str, ProfileInfo]: """ Get the users who are currently in a room. + Note: This is much slower than using the equivalent method + `DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`, + so this should only be used when wanting the users at a particular point + in the room. + Args: room_id: The ID of the room. latest_event_ids: Precomputed list of latest event IDs. Will be computed if None. Returns: Dictionary of user IDs to their profileinfo. """ - if not latest_event_ids: - latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) + assert latest_event_ids is not None logger.debug("calling resolve_state_groups from get_current_users_in_room") diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 6b68d8720c..3d98d3f5f8 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -69,6 +69,7 @@ class SQLBaseStore(metaclass=ABCMeta): self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) + self._attempt_to_invalidate_cache("get_users_in_room_with_profiles", (room_id,)) self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,)) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 2a8532f8c1..5fc3bb5a7d 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -205,8 +205,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]: sql = """ - SELECT user_id, display_name, avatar_url FROM room_memberships - WHERE room_id = ? AND membership = ? + SELECT state_key, display_name, avatar_url FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? """ txn.execute(sql, (room_id, Membership.JOIN)) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 7a082fdd21..a6bfb4902a 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -142,8 +142,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): batch_size (int): Maximum number of state events to process per cycle. """ - state = self.hs.get_state_handler() - # If we don't have progress filed, delete everything. if not progress: await self.delete_all_from_user_dir() @@ -197,7 +195,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): room_id ) - users_with_profile = await state.get_current_users_in_room(room_id) + users_with_profile = await self.get_users_in_room_with_profiles(room_id) user_ids = set(users_with_profile) # Update each user in the user directory. |