36 files changed, 846 insertions, 598 deletions
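The headline change in this patch set is a rate limit on Python's garbage collection: a generation is only collected once its allocation threshold has been exceeded *and* a minimum interval (the new `gc_min_interval` option, defaulting to `[1s, 10s, 30s]`) has elapsed since that generation was last collected. The following is a condensed, standalone sketch of the logic added to `synapse/metrics/__init__.py` further down; it is illustrative only, and the `maybe_collect` wrapper is hypothetical (in the patch this check runs inside the reactor tick wrapper).

import gc
import time

# Minimum seconds between collections of generations 0, 1 and 2 (the new default).
MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)
# Time of the last manual collection for each generation.
_last_gc = [0.0, 0.0, 0.0]

def maybe_collect() -> None:
    # Hypothetical wrapper for illustration; Synapse does this at the end of
    # each reactor iteration, since automatic GC has been disabled.
    threshold = gc.get_threshold()
    counts = gc.get_count()
    # Oldest generation first, so an object promoted by a younger collection
    # is not handled twice in the same pass.
    for i in (2, 1, 0):
        now = time.time()
        if threshold[i] < counts[i] and MIN_TIME_BETWEEN_GCS[i] < now - _last_gc[i]:
            gc.collect(i)
            _last_gc[i] = time.time()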
diff --git a/changelog.d/9881.feature b/changelog.d/9881.feature new file mode 100644 index 0000000000..088a517e02 --- /dev/null +++ b/changelog.d/9881.feature @@ -0,0 +1 @@ +Add experimental option to track memory usage of the caches. diff --git a/changelog.d/9882.misc b/changelog.d/9882.misc new file mode 100644 index 0000000000..facfa31f38 --- /dev/null +++ b/changelog.d/9882.misc @@ -0,0 +1 @@ +Export jemalloc stats to Prometheus if it is being used. diff --git a/changelog.d/9902.feature b/changelog.d/9902.feature new file mode 100644 index 0000000000..4d9f324d4e --- /dev/null +++ b/changelog.d/9902.feature @@ -0,0 +1 @@ +Add limits to how often Synapse will GC, ensuring that large servers do not end up GC thrashing if `gc_thresholds` has not been correctly set. diff --git a/changelog.d/9910.bugfix b/changelog.d/9910.bugfix new file mode 100644 index 0000000000..06d523fd46 --- /dev/null +++ b/changelog.d/9910.bugfix @@ -0,0 +1 @@ +Fix bug where user directory could get out of sync if room visibility and membership changed in quick succession. diff --git a/changelog.d/9910.feature b/changelog.d/9910.feature new file mode 100644 index 0000000000..54165cce18 --- /dev/null +++ b/changelog.d/9910.feature @@ -0,0 +1 @@ +Improve performance after joining a large room when presence is enabled. diff --git a/changelog.d/9916.misc b/changelog.d/9916.misc new file mode 100644 index 0000000000..401298fa3d --- /dev/null +++ b/changelog.d/9916.misc @@ -0,0 +1 @@ +Improve performance of handling presence when joining large rooms. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index d013725cdc..f469d6e54f 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -152,6 +152,16 @@ presence: # #gc_thresholds: [700, 10, 10] +# The minimum time in seconds between each GC for a generation, regardless of +# the GC thresholds. This ensures that we don't do GC too frequently. +# +# A value of `[1s, 10s, 30s]` indicates that a second must pass between consecutive +# generation 0 GCs, etc. +# +# Defaults to `[1s, 10s, 30s]`. +# +#gc_min_interval: [0.5s, 30s, 1m] + # Set the limit on the returned events in the timeline in the get # and sync operations. The default value is 100. -1 means no upper limit. 
# diff --git a/mypy.ini b/mypy.ini index a40f705b76..ea655a0d4d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -171,3 +171,6 @@ ignore_missing_imports = True [mypy-txacme.*] ignore_missing_imports = True + +[mypy-pympler.*] +ignore_missing_imports = True diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 638e01c1b2..59918d789e 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -37,6 +37,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.logging.context import PreserveLoggingContext from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.metrics.jemalloc import setup_jemalloc_stats from synapse.util.async_helpers import Linearizer from synapse.util.daemonize import daemonize_process from synapse.util.rlimit import change_resource_limit @@ -115,6 +116,7 @@ def start_reactor( def run(): logger.info("Running") + setup_jemalloc_stats() change_resource_limit(soft_file_limit) if gc_thresholds: gc.set_threshold(*gc_thresholds) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 1a15ceee81..f730cdbd78 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -454,6 +454,10 @@ def start(config_options): config.server.update_user_directory = False synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts + synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage + + if config.server.gc_seconds: + synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds hs = GenericWorkerServer( config.server_name, diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 8e78134bbe..b2501ee4d7 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -341,6 +341,10 @@ def setup(config_options): sys.exit(0) events.USE_FROZEN_DICTS = config.use_frozen_dicts + synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage + + if config.server.gc_seconds: + synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds hs = SynapseHomeServer( config.server_name, diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 41b9b3f51f..91165ee1ce 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -17,6 +17,8 @@ import re import threading from typing import Callable, Dict +from synapse.python_dependencies import DependencyException, check_requirements + from ._base import Config, ConfigError # The prefix for all cache factor-related environment variables @@ -189,6 +191,15 @@ class CacheConfig(Config): ) self.cache_factors[cache] = factor + self.track_memory_usage = cache_config.get("track_memory_usage", False) + if self.track_memory_usage: + try: + check_requirements("cache_memory") + except DependencyException as e: + raise ConfigError( + e.message # noqa: B306, DependencyException.message is a property + ) + # Resize all caches (if necessary) with the new factors we've loaded self.resize_all_caches() diff --git a/synapse/config/server.py b/synapse/config/server.py index 21ca7b33e3..c290a35a92 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -19,7 +19,7 @@ import logging import os.path import re from textwrap import indent -from typing import Any, Dict, Iterable, List, Optional, Set +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple import attr import yaml @@ -572,6 +572,7 @@ class ServerConfig(Config): _warn_if_webclient_configured(self.listeners) self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", 
None)) + self.gc_seconds = self.read_gc_intervals(config.get("gc_min_interval", None)) @attr.s class LimitRemoteRoomsConfig: @@ -917,6 +918,16 @@ class ServerConfig(Config): # #gc_thresholds: [700, 10, 10] + # The minimum time in seconds between each GC for a generation, regardless of + # the GC thresholds. This ensures that we don't do GC too frequently. + # + # A value of `[1s, 10s, 30s]` indicates that a second must pass between consecutive + # generation 0 GCs, etc. + # + # Defaults to `[1s, 10s, 30s]`. + # + #gc_min_interval: [0.5s, 30s, 1m] + # Set the limit on the returned events in the timeline in the get # and sync operations. The default value is 100. -1 means no upper limit. # @@ -1305,6 +1316,24 @@ class ServerConfig(Config): help="Turn on the twisted telnet manhole service on the given port.", ) + def read_gc_intervals(self, durations) -> Optional[Tuple[float, float, float]]: + """Reads the three durations for the GC min interval option, returning seconds.""" + if durations is None: + return None + + try: + if len(durations) != 3: + raise ValueError() + return ( + self.parse_duration(durations[0]) / 1000, + self.parse_duration(durations[1]) / 1000, + self.parse_duration(durations[2]) / 1000, + ) + except Exception: + raise ConfigError( + "Value of `gc_min_interval` must be a list of three durations if set" + ) + def is_threepid_reserved(reserved_threepids, threepid): """Check the threepid against the reserved threepid config diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 5f18ef7748..9efe13606a 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -16,8 +16,7 @@ import abc import logging import urllib -from collections import defaultdict -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple import attr from signedjson.key import ( @@ -42,17 +41,18 @@ from synapse.api.errors import ( SynapseError, ) from synapse.config.key import TrustedKeyServer +from synapse.events import EventBase +from synapse.events.utils import prune_event_dict from synapse.logging.context import ( PreserveLoggingContext, make_deferred_yieldable, - preserve_fn, run_in_background, ) +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict from synapse.util import unwrapFirstError -from synapse.util.async_helpers import yieldable_gather_results -from synapse.util.metrics import Measure +from synapse.util.async_helpers import Linearizer, yieldable_gather_results from synapse.util.retryutils import NotRetryingDestination if TYPE_CHECKING: @@ -74,8 +74,6 @@ class VerifyJsonRequest: minimum_valid_until_ts: time at which we require the signing key to be valid. (0 implies we don't care) - request_name: The name of the request. 
- key_ids: The set of key_ids to that could be used to verify the JSON object key_ready (Deferred[str, str, nacl.signing.VerifyKey]): @@ -88,20 +86,93 @@ class VerifyJsonRequest: """ server_name = attr.ib(type=str) - json_object = attr.ib(type=JsonDict) + json_object_callback = attr.ib(type=Callable[[], JsonDict]) minimum_valid_until_ts = attr.ib(type=int) - request_name = attr.ib(type=str) - key_ids = attr.ib(init=False, type=List[str]) - key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred) + key_ids = attr.ib(type=List[str]) - def __attrs_post_init__(self): - self.key_ids = signature_ids(self.json_object, self.server_name) + @staticmethod + def from_json_object( + server_name: str, minimum_valid_until_ms: int, json_object: JsonDict + ): + key_ids = signature_ids(json_object, server_name) + return VerifyJsonRequest( + server_name, lambda: json_object, minimum_valid_until_ms, key_ids + ) + + @staticmethod + def from_event( + server_name: str, + minimum_valid_until_ms: int, + event: EventBase, + ): + key_ids = list(event.signatures.get(server_name, [])) + return VerifyJsonRequest( + server_name, + lambda: prune_event_dict(event.room_version, event.get_pdu_json()), + minimum_valid_until_ms, + key_ids, + ) class KeyLookupError(ValueError): pass +@attr.s(slots=True) +class _QueueValue: + server_name = attr.ib(type=str) + minimum_valid_until_ts = attr.ib(type=int) + key_ids = attr.ib(type=List[str]) + + +class _Queue: + def __init__(self, name, clock, process_items): + self._name = name + self._clock = clock + self._is_processing = False + self._next_values = [] + + self.process_items = process_items + + async def add_to_queue(self, value: _QueueValue) -> Dict[str, FetchKeyResult]: + d = defer.Deferred() + self._next_values.append((value, d)) + + if not self._is_processing: + run_as_background_process(self._name, self._unsafe_process) + + return await make_deferred_yieldable(d) + + async def _unsafe_process(self): + try: + if self._is_processing: + return + + self._is_processing = True + + while self._next_values: + # We purposefully defer to the next loop. + await self._clock.sleep(0) + + next_values = self._next_values + self._next_values = [] + + try: + values = [value for value, _ in next_values] + results = await self.process_items(values) + + for value, deferred in next_values: + with PreserveLoggingContext(): + deferred.callback(results.get(value.server_name, {})) + + except Exception as e: + for _, deferred in next_values: + deferred.errback(e) + + finally: + self._is_processing = False + + class Keyring: def __init__( self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None @@ -116,12 +187,7 @@ class Keyring: ) self._key_fetchers = key_fetchers - # map from server name to Deferred. Has an entry for each server with - # an ongoing key download; the Deferred completes once the download - # completes. - # - # These are regular, logcontext-agnostic Deferreds. - self.key_downloads = {} # type: Dict[str, defer.Deferred] + self._server_queue = Linearizer("keyring_server") def verify_json_for_server( self, @@ -130,365 +196,150 @@ class Keyring: validity_time: int, request_name: str, ) -> defer.Deferred: - """Verify that a JSON object has been signed by a given server - - Args: - server_name: name of the server which must have signed this object - - json_object: object to be checked - - validity_time: timestamp at which we require the signing key to - be valid. 
(0 implies we don't care) - - request_name: an identifier for this json object (eg, an event id) - for logging. - - Returns: - Deferred[None]: completes if the the object was correctly signed, otherwise - errbacks with an error - """ - req = VerifyJsonRequest(server_name, json_object, validity_time, request_name) - requests = (req,) - return make_deferred_yieldable(self._verify_objects(requests)[0]) + request = VerifyJsonRequest.from_json_object( + server_name, + validity_time, + json_object, + ) + return defer.ensureDeferred(self._verify_object(request)) def verify_json_objects_for_server( self, server_and_json: Iterable[Tuple[str, dict, int, str]] ) -> List[defer.Deferred]: - """Bulk verifies signatures of json objects, bulk fetching keys as - necessary. - - Args: - server_and_json: - Iterable of (server_name, json_object, validity_time, request_name) - tuples. - - validity_time is a timestamp at which the signing key must be - valid. - - request_name is an identifier for this json object (eg, an event id) - for logging. - - Returns: - List<Deferred[None]>: for each input triplet, a deferred indicating success - or failure to verify each json object's signature for the given - server_name. The deferreds run their callbacks in the sentinel - logcontext. - """ - return self._verify_objects( - VerifyJsonRequest(server_name, json_object, validity_time, request_name) - for server_name, json_object, validity_time, request_name in server_and_json - ) - - def _verify_objects( - self, verify_requests: Iterable[VerifyJsonRequest] - ) -> List[defer.Deferred]: - """Does the work of verify_json_[objects_]for_server - - - Args: - verify_requests: Iterable of verification requests. - - Returns: - List<Deferred[None]>: for each input item, a deferred indicating success - or failure to verify each json object's signature for the given - server_name. The deferreds run their callbacks in the sentinel - logcontext. - """ - # a list of VerifyJsonRequests which are awaiting a key lookup - key_lookups = [] - handle = preserve_fn(_handle_key_deferred) - - def process(verify_request: VerifyJsonRequest) -> defer.Deferred: - """Process an entry in the request list - - Adds a key request to key_lookups, and returns a deferred which - will complete or fail (in the sentinel context) when verification completes. - """ - if not verify_request.key_ids: - return defer.fail( - SynapseError( - 400, - "Not signed by %s" % (verify_request.server_name,), - Codes.UNAUTHORIZED, - ) + return [ + defer.ensureDeferred( + run_in_background( + self._verify_object, + VerifyJsonRequest.from_json_object( + server_name, + validity_time, + json_object, + ), ) - - logger.debug( - "Verifying %s for %s with key_ids %s, min_validity %i", - verify_request.request_name, - verify_request.server_name, - verify_request.key_ids, - verify_request.minimum_valid_until_ts, ) + for server_name, json_object, validity_time, request_name in server_and_json + ] - # add the key request to the queue, but don't start it off yet. - key_lookups.append(verify_request) - - # now run _handle_key_deferred, which will wait for the key request - # to complete and then do the verification. 
- # - # We want _handle_key_request to log to the right context, so we - # wrap it with preserve_fn (aka run_in_background) - return handle(verify_request) - - results = [process(r) for r in verify_requests] - - if key_lookups: - run_in_background(self._start_key_lookups, key_lookups) - - return results - - async def _start_key_lookups( - self, verify_requests: List[VerifyJsonRequest] - ) -> None: - """Sets off the key fetches for each verify request - - Once each fetch completes, verify_request.key_ready will be resolved. - - Args: - verify_requests: - """ - - try: - # map from server name to a set of outstanding request ids - server_to_request_ids = {} # type: Dict[str, Set[int]] - - for verify_request in verify_requests: - server_name = verify_request.server_name - request_id = id(verify_request) - server_to_request_ids.setdefault(server_name, set()).add(request_id) - - # Wait for any previous lookups to complete before proceeding. - await self.wait_for_previous_lookups(server_to_request_ids.keys()) - - # take out a lock on each of the servers by sticking a Deferred in - # key_downloads - for server_name in server_to_request_ids.keys(): - self.key_downloads[server_name] = defer.Deferred() - logger.debug("Got key lookup lock on %s", server_name) - - # When we've finished fetching all the keys for a given server_name, - # drop the lock by resolving the deferred in key_downloads. - def drop_server_lock(server_name): - d = self.key_downloads.pop(server_name) - d.callback(None) - - def lookup_done(res, verify_request): - server_name = verify_request.server_name - server_requests = server_to_request_ids[server_name] - server_requests.remove(id(verify_request)) - - # if there are no more requests for this server, we can drop the lock. - if not server_requests: - logger.debug("Releasing key lookup lock on %s", server_name) - drop_server_lock(server_name) - - return res - - for verify_request in verify_requests: - verify_request.key_ready.addBoth(lookup_done, verify_request) - - # Actually start fetching keys. - self._get_server_verify_keys(verify_requests) - except Exception: - logger.exception("Error starting key lookups") - - async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None: - """Waits for any previous key lookups for the given servers to finish. - - Args: - server_names: list of servers which we want to look up - - Returns: - Resolves once all key lookups for the given servers have - completed. Follows the synapse rules of logcontext preservation. - """ - loop_count = 1 - while True: - wait_on = [ - (server_name, self.key_downloads[server_name]) - for server_name in server_names - if server_name in self.key_downloads - ] - if not wait_on: - break - logger.info( - "Waiting for existing lookups for %s to complete [loop %i]", - [w[0] for w in wait_on], - loop_count, + def verify_events_for_server( + self, server_and_json: Iterable[Tuple[str, EventBase, int]] + ) -> List[defer.Deferred]: + return [ + run_in_background( + self._verify_object, + VerifyJsonRequest.from_event( + server_name, + validity_time, + event, + ), ) - with PreserveLoggingContext(): - await defer.DeferredList((w[1] for w in wait_on)) - - loop_count += 1 - - def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None: - """Tries to find at least one key for each verify request - - For each verify_request, verify_request.key_ready is called back with - params (server_name, key_id, VerifyKey) if a key is found, or errbacked - with a SynapseError if none of the keys are found. 
+ for server_name, event, validity_time in server_and_json + ] + + async def _verify_object(self, verify_request: VerifyJsonRequest): + # TODO: Use a batching thing. + with (await self._server_queue.queue(verify_request.server_name)): + found_keys: Dict[str, FetchKeyResult] = {} + missing_key_ids = set(verify_request.key_ids) + for fetcher in self._key_fetchers: + if not missing_key_ids: + break + + keys = await fetcher.get_keys( + verify_request.server_name, + list(missing_key_ids), + verify_request.minimum_valid_until_ts, + ) - Args: - verify_requests: list of verify requests - """ + for key_id, key in keys.items(): + if not key: + continue - remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} + if key.valid_until_ts < verify_request.minimum_valid_until_ts: + continue - async def do_iterations(): - try: - with Measure(self.clock, "get_server_verify_keys"): - for f in self._key_fetchers: - if not remaining_requests: - return - await self._attempt_key_fetches_with_fetcher( - f, remaining_requests - ) - - # look for any requests which weren't satisfied - while remaining_requests: - verify_request = remaining_requests.pop() - rq_str = ( - "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)" - % ( - verify_request.server_name, - verify_request.key_ids, - verify_request.minimum_valid_until_ts, - ) - ) - - # If we run the errback immediately, it may cancel our - # loggingcontext while we are still in it, so instead we - # schedule it for the next time round the reactor. - # - # (this also ensures that we don't get a stack overflow if we - # has a massive queue of lookups waiting for this server). - self.clock.call_later( - 0, - verify_request.key_ready.errback, - SynapseError( - 401, - "Failed to find any key to satisfy %s" % (rq_str,), - Codes.UNAUTHORIZED, - ), - ) - except Exception as err: - # we don't really expect to get here, because any errors should already - # have been caught and logged. But if we do, let's log the error and make - # sure that all of the deferreds are resolved. - logger.error("Unexpected error in _get_server_verify_keys: %s", err) - with PreserveLoggingContext(): - for verify_request in remaining_requests: - if not verify_request.key_ready.called: - verify_request.key_ready.errback(err) - - run_in_background(do_iterations) - - async def _attempt_key_fetches_with_fetcher( - self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest] - ): - """Use a key fetcher to attempt to satisfy some key requests + existing_key = found_keys.get(key_id) + if existing_key: + if key.valid_until_ts <= existing_key.valid_until_ts: + continue - Args: - fetcher: fetcher to use to fetch the keys - remaining_requests: outstanding key requests. - Any successfully-completed requests will be removed from the list. - """ - # The keys to fetch. - # server_name -> key_id -> min_valid_ts - missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]] + found_keys[key_id] = key - for verify_request in remaining_requests: - # any completed requests should already have been removed - assert not verify_request.key_ready.called - keys_for_server = missing_keys[verify_request.server_name] + missing_key_ids.difference_update(found_keys) - for key_id in verify_request.key_ids: - # If we have several requests for the same key, then we only need to - # request that key once, but we should do so with the greatest - # min_valid_until_ts of the requests, so that we can satisfy all of - # the requests. 
- keys_for_server[key_id] = max( - keys_for_server.get(key_id, -1), - verify_request.minimum_valid_until_ts, + if missing_key_ids: + raise SynapseError( + 400, + "Missing keys for %s: %s" + % (verify_request.server_name, missing_key_ids), + Codes.UNAUTHORIZED, ) - results = await fetcher.get_keys(missing_keys) - - completed = [] - for verify_request in remaining_requests: - server_name = verify_request.server_name - - # see if any of the keys we got this time are sufficient to - # complete this VerifyJsonRequest. - result_keys = results.get(server_name, {}) for key_id in verify_request.key_ids: - fetch_key_result = result_keys.get(key_id) - if not fetch_key_result: - # we didn't get a result for this key - continue - - if ( - fetch_key_result.valid_until_ts - < verify_request.minimum_valid_until_ts - ): - # key was not valid at this point - continue - - # we have a valid key for this request. If we run the callback - # immediately, it may cancel our loggingcontext while we are still in - # it, so instead we schedule it for the next time round the reactor. - # - # (this also ensures that we don't get a stack overflow if we had - # a massive queue of lookups waiting for this server). - logger.debug( - "Found key %s:%s for %s", - server_name, - key_id, - verify_request.request_name, - ) - self.clock.call_later( - 0, - verify_request.key_ready.callback, - (server_name, key_id, fetch_key_result.verify_key), - ) - completed.append(verify_request) - break - - remaining_requests.difference_update(completed) + verify_key = found_keys[key_id].verify_key + try: + json_object = verify_request.json_object_callback() + verify_signed_json( + json_object, + verify_request.server_name, + verify_key, + ) + except SignatureVerifyException as e: + logger.debug( + "Error verifying signature for %s:%s:%s with key %s: %s", + verify_request.server_name, + verify_key.alg, + verify_key.version, + encode_verify_key_base64(verify_key), + str(e), + ) + raise SynapseError( + 401, + "Invalid signature for server %s with key %s:%s: %s" + % ( + verify_request.server_name, + verify_key.alg, + verify_key.version, + str(e), + ), + Codes.UNAUTHORIZED, + ) class KeyFetcher(metaclass=abc.ABCMeta): - @abc.abstractmethod + def __init__(self, hs: "HomeServer"): + self._queue = _Queue(self.__class__.__name__, hs.get_clock(), self._fetch_keys) + async def get_keys( - self, keys_to_fetch: Dict[str, Dict[str, int]] - ) -> Dict[str, Dict[str, FetchKeyResult]]: - """ - Args: - keys_to_fetch: - the keys to be fetched. 
server_name -> key_id -> min_valid_ts + self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int + ) -> Dict[str, FetchKeyResult]: + return await self._queue.add_to_queue( + _QueueValue( + server_name=server_name, + key_ids=key_ids, + minimum_valid_until_ts=minimum_valid_until_ts, + ) + ) - Returns: - Map from server_name -> key_id -> FetchKeyResult - """ - raise NotImplementedError + @abc.abstractmethod + async def _fetch_keys( + self, keys_to_fetch: List[_QueueValue] + ) -> Dict[str, Dict[str, FetchKeyResult]]: + pass class StoreKeyFetcher(KeyFetcher): """KeyFetcher impl which fetches keys from our data store""" def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + super().__init__(hs) - async def get_keys( - self, keys_to_fetch: Dict[str, Dict[str, int]] - ) -> Dict[str, Dict[str, FetchKeyResult]]: - """see KeyFetcher.get_keys""" + self.store = hs.get_datastore() + async def _fetch_keys(self, keys_to_fetch: List[_QueueValue]): key_ids_to_fetch = ( - (server_name, key_id) - for server_name, keys_for_server in keys_to_fetch.items() - for key_id in keys_for_server.keys() + (queue_value.server_name, key_id) + for queue_value in keys_to_fetch + for key_id in queue_value.key_ids ) res = await self.store.get_server_verify_keys(key_ids_to_fetch) @@ -500,6 +351,8 @@ class StoreKeyFetcher(KeyFetcher): class BaseV2KeyFetcher(KeyFetcher): def __init__(self, hs: "HomeServer"): + super().__init__(hs) + self.store = hs.get_datastore() self.config = hs.config @@ -607,10 +460,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): self.client = hs.get_federation_http_client() self.key_servers = self.config.key_servers - async def get_keys( - self, keys_to_fetch: Dict[str, Dict[str, int]] + async def _fetch_keys( + self, keys_to_fetch: List[_QueueValue] ) -> Dict[str, Dict[str, FetchKeyResult]]: - """see KeyFetcher.get_keys""" + """see KeyFetcher._fetch_keys""" async def get_key(key_server: TrustedKeyServer) -> Dict: try: @@ -646,12 +499,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): return union_of_keys async def get_server_verify_key_v2_indirect( - self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer + self, keys_to_fetch: List[_QueueValue], key_server: TrustedKeyServer ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: keys_to_fetch: - the keys to be fetched. server_name -> key_id -> min_valid_ts + the keys to be fetched. 
key_server: notary server to query for the keys @@ -665,7 +518,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): perspective_name = key_server.server_name logger.info( "Requesting keys %s from notary server %s", - keys_to_fetch.items(), + keys_to_fetch, perspective_name, ) @@ -675,11 +528,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): path="/_matrix/key/v2/query", data={ "server_keys": { - server_name: { - key_id: {"minimum_valid_until_ts": min_valid_ts} - for key_id, min_valid_ts in server_keys.items() + queue_value.server_name: { + key_id: { + "minimum_valid_until_ts": queue_value.minimum_valid_until_ts, + } + for key_id in queue_value.key_ids } - for server_name, server_keys in keys_to_fetch.items() + for queue_value in keys_to_fetch } }, ) @@ -779,8 +634,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher): self.clock = hs.get_clock() self.client = hs.get_federation_http_client() - async def get_keys( - self, keys_to_fetch: Dict[str, Dict[str, int]] + async def _fetch_keys( + self, keys_to_fetch: List[_QueueValue] ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: @@ -793,8 +648,10 @@ class ServerKeyFetcher(BaseV2KeyFetcher): results = {} - async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None: - server_name, key_ids = key_to_fetch_item + async def get_key(key_to_fetch_item: _QueueValue) -> None: + server_name = key_to_fetch_item.server_name + key_ids = key_to_fetch_item.key_ids + try: keys = await self.get_server_verify_key_v2_direct(server_name, key_ids) results[server_name] = keys @@ -805,7 +662,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): except Exception: logger.exception("Error getting keys %s from %s", key_ids, server_name) - await yieldable_gather_results(get_key, keys_to_fetch.items()) + await yieldable_gather_results(get_key, keys_to_fetch) return results async def get_server_verify_key_v2_direct( @@ -877,37 +734,3 @@ class ServerKeyFetcher(BaseV2KeyFetcher): keys.update(response_keys) return keys - - -async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None: - """Waits for the key to become available, and then performs a verification - - Args: - verify_request: - - Raises: - SynapseError if there was a problem performing the verification - """ - server_name = verify_request.server_name - with PreserveLoggingContext(): - _, key_id, verify_key = await verify_request.key_ready - - json_object = verify_request.json_object - - try: - verify_signed_json(json_object, server_name, verify_key) - except SignatureVerifyException as e: - logger.debug( - "Error verifying signature for %s:%s:%s with key %s: %s", - server_name, - verify_key.alg, - verify_key.version, - encode_verify_key_base64(verify_key), - str(e), - ) - raise SynapseError( - 401, - "Invalid signature for server %s with key %s:%s: %s" - % (server_name, verify_key.alg, verify_key.version, str(e)), - Codes.UNAUTHORIZED, - ) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index a6471ce6cc..dd631c7794 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -73,10 +73,10 @@ class FederationBase: * throws a SynapseError if the signature check failed. 
The deferreds run their callbacks in the sentinel """ - deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus) - ctx = current_context() + deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus) + @defer.inlineCallbacks def callback(_, pdu: EventBase): with PreserveLoggingContext(ctx): @@ -135,11 +135,7 @@ class FederationBase: return deferreds -class PduToCheckSig( - namedtuple( - "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"] - ) -): +class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])): pass @@ -182,7 +178,6 @@ def _check_sigs_on_pdus( pdus_to_check = [ PduToCheckSig( pdu=p, - redacted_pdu_json=prune_event(p).get_pdu_json(), sender_domain=get_domain_from_id(p.sender), deferreds=[], ) @@ -193,13 +188,12 @@ def _check_sigs_on_pdus( # (except if its a 3pid invite, in which case it may be sent by any server) pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)] - more_deferreds = keyring.verify_json_objects_for_server( + more_deferreds = keyring.verify_events_for_server( [ ( p.sender_domain, - p.redacted_pdu_json, + p.pdu, p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, - p.pdu.event_id, ) for p in pdus_to_check_sender ] @@ -228,13 +222,12 @@ def _check_sigs_on_pdus( if p.sender_domain != get_domain_from_id(p.pdu.event_id) ] - more_deferreds = keyring.verify_json_objects_for_server( + more_deferreds = keyring.verify_events_for_server( [ ( get_domain_from_id(p.pdu.event_id), - p.redacted_pdu_json, + p.pdu, p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, - p.pdu.event_id, ) for p in pdus_to_check_event_id ] diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index a5b6a61195..90acc23886 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -33,6 +33,7 @@ from typing import ( ) import attr +import ijson from prometheus_client import Counter from twisted.internet import defer @@ -55,11 +56,16 @@ from synapse.api.room_versions import ( ) from synapse.events import EventBase, builder from synapse.federation.federation_base import FederationBase, event_from_pdu_json -from synapse.logging.context import make_deferred_yieldable, preserve_fn +from synapse.logging.context import ( + get_thread_resource_usage, + make_deferred_yieldable, + preserve_fn, +) from synapse.logging.utils import log_function from synapse.types import JsonDict, get_domain_from_id from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.iterutils import batch_iter from synapse.util.retryutils import NotRetryingDestination if TYPE_CHECKING: @@ -385,7 +391,6 @@ class FederationClient(FederationBase): Returns: A list of PDUs that have valid signatures and hashes. 
""" - deferreds = self._check_sigs_and_hashes(room_version, pdus) async def handle_check_result(pdu: EventBase, deferred: Deferred): try: @@ -420,6 +425,7 @@ class FederationClient(FederationBase): return res handle = preserve_fn(handle_check_result) + deferreds = self._check_sigs_and_hashes(room_version, pdus) deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)] valid_pdus = await make_deferred_yieldable( @@ -667,19 +673,37 @@ class FederationClient(FederationBase): async def send_request(destination) -> Dict[str, Any]: content = await self._do_send_join(destination, pdu) - logger.debug("Got content: %s", content) + # logger.debug("Got content: %s", content.getvalue()) - state = [ - event_from_pdu_json(p, room_version, outlier=True) - for p in content.get("state", []) - ] + # logger.info("send_join content: %d", len(content)) - auth_chain = [ - event_from_pdu_json(p, room_version, outlier=True) - for p in content.get("auth_chain", []) - ] + content.seek(0) + + r = get_thread_resource_usage() + logger.info("Memory before state: %s", r.ru_maxrss) + + state = [] + for i, p in enumerate(ijson.items(content, "state.item")): + state.append(event_from_pdu_json(p, room_version, outlier=True)) + if i % 1000 == 999: + await self._clock.sleep(0) + + r = get_thread_resource_usage() + logger.info("Memory after state: %s", r.ru_maxrss) + + logger.info("Parsed state: %d", len(state)) + content.seek(0) + + auth_chain = [] + for i, p in enumerate(ijson.items(content, "auth_chain.item")): + auth_chain.append(event_from_pdu_json(p, room_version, outlier=True)) + if i % 1000 == 999: + await self._clock.sleep(0) - pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)} + r = get_thread_resource_usage() + logger.info("Memory after: %s", r.ru_maxrss) + + logger.info("Parsed auth chain: %d", len(auth_chain)) create_event = None for e in state: @@ -704,12 +728,19 @@ class FederationClient(FederationBase): % (create_room_version,) ) - valid_pdus = await self._check_sigs_and_hash_and_fetch( - destination, - list(pdus.values()), - outlier=True, - room_version=room_version, - ) + valid_pdus = [] + + for chunk in batch_iter(itertools.chain(state, auth_chain), 1000): + logger.info("Handling next _check_sigs_and_hash_and_fetch chunk") + new_valid_pdus = await self._check_sigs_and_hash_and_fetch( + destination, + chunk, + outlier=True, + room_version=room_version, + ) + valid_pdus.extend(new_valid_pdus) + + logger.info("_check_sigs_and_hash_and_fetch done") valid_pdus_map = {p.event_id: p for p in valid_pdus} @@ -744,6 +775,8 @@ class FederationClient(FederationBase): % (auth_chain_create_events,) ) + logger.info("Returning from send_join") + return { "state": signed_state, "auth_chain": signed_auth, @@ -769,6 +802,8 @@ class FederationClient(FederationBase): if not self._is_unknown_endpoint(e): raise + raise NotImplementedError() + logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API") resp = await self.transport_layer.send_join_v1( diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index ada322a81e..9c0a105f36 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -244,7 +244,10 @@ class TransportLayerClient: path = _create_v1_path("/send_join/%s/%s", room_id, event_id) response = await self.client.put_json( - destination=destination, path=path, data=content + destination=destination, + path=path, + data=content, + return_string_io=True, ) return response @@ -254,7 +257,10 
@@ class TransportLayerClient: path = _create_v2_path("/send_join/%s/%s", room_id, event_id) response = await self.client.put_json( - destination=destination, path=path, data=content + destination=destination, + path=path, + data=content, + return_string_io=True, ) return response diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index de1b14cde3..4064a2b859 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -78,7 +78,7 @@ class DirectoryHandler(BaseHandler): # TODO(erikj): Add transactions. # TODO(erikj): Check if there is a current association. if not servers: - users = await self.state.get_current_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) servers = {get_domain_from_id(u) for u in users} if not servers: @@ -270,7 +270,7 @@ class DirectoryHandler(BaseHandler): Codes.NOT_FOUND, ) - users = await self.state.get_current_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) extra_servers = {get_domain_from_id(u) for u in users} servers = set(extra_servers) | set(servers) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index d82144d7fa..f134f1e234 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -103,7 +103,7 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = await self.state.get_current_users_in_room( + users = await self.store.get_users_in_room( event.room_id ) # type: Iterable[str] else: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 9d867aaf4d..e56242c63e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -552,7 +552,7 @@ class FederationHandler(BaseHandler): destination: str, room_id: str, event_id: str, - ) -> Tuple[List[EventBase], List[EventBase]]: + ) -> List[EventBase]: """Requests all of the room state at a given event from a remote homeserver. Args: @@ -573,11 +573,10 @@ class FederationHandler(BaseHandler): desired_events = set(state_event_ids + auth_event_ids) - event_map = await self._get_events_from_store_or_dest( + failed_to_fetch = await self._get_events_from_store_or_dest( destination, room_id, desired_events ) - failed_to_fetch = desired_events - event_map.keys() if failed_to_fetch: logger.warning( "Failed to fetch missing state/auth events for %s %s", @@ -585,18 +584,44 @@ class FederationHandler(BaseHandler): failed_to_fetch, ) + event_map = await self.store.get_events(state_event_ids, allow_rejected=True) + remote_state = [ event_map[e_id] for e_id in state_event_ids if e_id in event_map ] - auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] - auth_chain.sort(key=lambda e: e.depth) + # check for events which were in the wrong room. + # + # this can happen if a remote server claims that the state or + # auth_events at an event in room A are actually events in room B + + bad_events = [ + (event_id, event.room_id) + for idx, event in enumerate(remote_state) + if event.room_id != room_id + ] + + for bad_event_id, bad_room_id in bad_events: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned auth/state set. 
+ logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + bad_event_id, + bad_room_id, + room_id, + ) + + if bad_events: + remote_state = [e for e in remote_state if e.room_id == room_id] - return remote_state, auth_chain + return remote_state async def _get_events_from_store_or_dest( self, destination: str, room_id: str, event_ids: Iterable[str] - ) -> Dict[str, EventBase]: + ) -> Set[str]: """Fetch events from a remote destination, checking if we already have them. Persists any events we don't already have as outliers. @@ -613,54 +638,25 @@ class FederationHandler(BaseHandler): Returns: map from event_id to event """ - fetched_events = await self.store.get_events(event_ids, allow_rejected=True) - - missing_events = set(event_ids) - fetched_events.keys() + have_events = await self.store.have_seen_events(event_ids) - if missing_events: - logger.debug( - "Fetching unknown state/auth events %s for room %s", - missing_events, - room_id, - ) - - await self._get_events_and_persist( - destination=destination, room_id=room_id, events=missing_events - ) - - # we need to make sure we re-load from the database to get the rejected - # state correct. - fetched_events.update( - (await self.store.get_events(missing_events, allow_rejected=True)) - ) - - # check for events which were in the wrong room. - # - # this can happen if a remote server claims that the state or - # auth_events at an event in room A are actually events in room B + missing_events = set(event_ids) - have_events - bad_events = [ - (event_id, event.room_id) - for event_id, event in fetched_events.items() - if event.room_id != room_id - ] + if not missing_events: + return set() - for bad_event_id, bad_room_id in bad_events: - # This is a bogus situation, but since we may only discover it a long time - # after it happened, we try our best to carry on, by just omitting the - # bad events from the returned auth/state set. - logger.warning( - "Remote server %s claims event %s in room %s is an auth/state " - "event in room %s", - destination, - bad_event_id, - bad_room_id, - room_id, - ) + logger.debug( + "Fetching unknown state/auth events %s for room %s", + missing_events, + room_id, + ) - del fetched_events[bad_event_id] + await self._get_events_and_persist( + destination=destination, room_id=room_id, events=missing_events + ) - return fetched_events + new_events = await self.store.have_seen_events(missing_events) + return missing_events - new_events async def _get_state_after_missing_prev_event( self, @@ -963,27 +959,23 @@ class FederationHandler(BaseHandler): # For each edge get the current state. 
- auth_events = {} state_events = {} events_to_state = {} for e_id in edges: - state, auth = await self._get_state_for_room( + state = await self._get_state_for_room( destination=dest, room_id=room_id, event_id=e_id, ) - auth_events.update({a.event_id: a for a in auth}) - auth_events.update({s.event_id: s for s in state}) state_events.update({s.event_id: s for s in state}) events_to_state[e_id] = state required_auth = { a_id - for event in events - + list(state_events.values()) - + list(auth_events.values()) + for event in events + list(state_events.values()) for a_id in event.auth_event_ids() } + auth_events = await self.store.get_events(required_auth, allow_rejected=True) auth_events.update( {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} ) @@ -1452,7 +1444,7 @@ class FederationHandler(BaseHandler): # room stuff after join currently doesn't work on workers. assert self.config.worker.worker_app is None - logger.debug("Joining %s to %s", joinee, room_id) + logger.info("Joining %s to %s", joinee, room_id) origin, event, room_version_obj = await self._make_and_verify_event( target_hosts, @@ -1463,6 +1455,8 @@ class FederationHandler(BaseHandler): params={"ver": KNOWN_ROOM_VERSIONS}, ) + logger.info("make_join done from %s", origin) + # This shouldn't happen, because the RoomMemberHandler has a # linearizer lock which only allows one operation per user per room # at a time - so this is just paranoia. @@ -1482,10 +1476,13 @@ class FederationHandler(BaseHandler): except ValueError: pass + logger.info("Sending join") ret = await self.federation_client.send_join( host_list, event, room_version_obj ) + logger.info("send join done") + origin = ret["origin"] state = ret["state"] auth_chain = ret["auth_chain"] @@ -1510,10 +1507,14 @@ class FederationHandler(BaseHandler): room_version=room_version_obj, ) + logger.info("Persisting auth true") + max_stream_id = await self._persist_auth_tree( origin, room_id, auth_chain, state, event, room_version_obj ) + logger.info("Persisted auth true") + # We wait here until this instance has seen the events come down # replication (if we're using replication) as the below uses caches. 
await self._replication.wait_for_stream_position( @@ -2166,6 +2167,8 @@ class FederationHandler(BaseHandler): ctx = await self.state_handler.compute_event_context(e) events_to_context[e.event_id] = ctx + logger.info("Computed contexts") + event_map = { e.event_id: e for e in itertools.chain(auth_events, state, [event]) } @@ -2207,6 +2210,8 @@ class FederationHandler(BaseHandler): else: logger.info("Failed to find auth event %r", e_id) + logger.info("Got missing events") + for e in itertools.chain(auth_events, state, [event]): auth_for_e = { (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id] @@ -2231,6 +2236,8 @@ class FederationHandler(BaseHandler): raise events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR + logger.info("Authed events") + await self.persist_events_and_notify( room_id, [ @@ -2239,10 +2246,14 @@ class FederationHandler(BaseHandler): ], ) + logger.info("Persisted events") + new_event_context = await self.state_handler.compute_event_context( event, old_state=state ) + logger.info("Computed context") + return await self.persist_events_and_notify( room_id, [(event, new_event_context)] ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d4684755bc..9919cccb19 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -258,7 +258,7 @@ class MessageHandler: "Getting joined members after leaving is not implemented" ) - users_with_profile = await self.state.get_current_users_in_room(room_id) + users_with_profile = await self.store.get_users_in_room_with_profiles(room_id) # If this is an AS, double check that they are allowed to see the members. # This can either be because the AS user is in the room or because there diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index ebbc234334..6fd1f34289 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1183,7 +1183,16 @@ class PresenceHandler(BasePresenceHandler): max_pos, deltas = await self.store.get_current_state_deltas( self._event_pos, room_max_stream_ordering ) - await self._handle_state_delta(deltas) + + # We may get multiple deltas for different rooms, but we want to + # handle them on a room by room basis, so we batch them up by + # room. + deltas_by_room: Dict[str, List[JsonDict]] = {} + for delta in deltas: + deltas_by_room.setdefault(delta["room_id"], []).append(delta) + + for room_id, deltas_for_room in deltas_by_room.items(): + await self._handle_state_delta(room_id, deltas_for_room) self._event_pos = max_pos @@ -1192,17 +1201,21 @@ class PresenceHandler(BasePresenceHandler): max_pos ) - async def _handle_state_delta(self, deltas: List[JsonDict]) -> None: - """Process current state deltas to find new joins that need to be - handled. + async def _handle_state_delta(self, room_id: str, deltas: List[JsonDict]) -> None: + """Process current state deltas for the room to find new joins that need + to be handled. """ - # A map of destination to a set of user state that they should receive - presence_destinations = {} # type: Dict[str, Set[UserPresenceState]] + + # Sets of newly joined users. Note that if the local server is + # joining a remote room for the first time we'll see both the joining + # user and all remote users as newly joined. 
+ newly_joined_users = set() for delta in deltas: + assert room_id == delta["room_id"] + typ = delta["type"] state_key = delta["state_key"] - room_id = delta["room_id"] event_id = delta["event_id"] prev_event_id = delta["prev_event_id"] @@ -1231,72 +1244,55 @@ class PresenceHandler(BasePresenceHandler): # Ignore changes to join events. continue - # Retrieve any user presence state updates that need to be sent as a result, - # and the destinations that need to receive it - destinations, user_presence_states = await self._on_user_joined_room( - room_id, state_key - ) - - # Insert the destinations and respective updates into our destinations dict - for destination in destinations: - presence_destinations.setdefault(destination, set()).update( - user_presence_states - ) - - # Send out user presence updates for each destination - for destination, user_state_set in presence_destinations.items(): - self._federation_queue.send_presence_to_destinations( - destinations=[destination], states=user_state_set - ) - - async def _on_user_joined_room( - self, room_id: str, user_id: str - ) -> Tuple[List[str], List[UserPresenceState]]: - """Called when we detect a user joining the room via the current state - delta stream. Returns the destinations that need to be updated and the - presence updates to send to them. - - Args: - room_id: The ID of the room that the user has joined. - user_id: The ID of the user that has joined the room. - - Returns: - A tuple of destinations and presence updates to send to them. - """ - if self.is_mine_id(user_id): - # If this is a local user then we need to send their presence - # out to hosts in the room (who don't already have it) - - # TODO: We should be able to filter the hosts down to those that - # haven't previously seen the user - - remote_hosts = await self.state.get_current_hosts_in_room(room_id) + newly_joined_users.add(state_key) - # Filter out ourselves. - filtered_remote_hosts = [ - host for host in remote_hosts if host != self.server_name - ] - - state = await self.current_state_for_user(user_id) - return filtered_remote_hosts, [state] - else: - # A remote user has joined the room, so we need to: - # 1. Check if this is a new server in the room - # 2. If so send any presence they don't already have for - # local users in the room. - - # TODO: We should be able to filter the users down to those that - # the server hasn't previously seen - - # TODO: Check that this is actually a new server joining the - # room. - - remote_host = get_domain_from_id(user_id) + if not newly_joined_users: + # If nobody has joined then there's nothing to do. + return - users = await self.state.get_current_users_in_room(room_id) - user_ids = list(filter(self.is_mine_id, users)) + # We want to send: + # 1. presence states of all local users in the room to newly joined + # remote servers + # 2. presence states of newly joined users to all remote servers in + # the room. + # + # TODO: Only send presence states to remote hosts that don't already + # have them (because they already share rooms). + + # Get all the users who were already in the room, by fetching the + # current users in the room and removing the newly joined users. 
+ users = await self.store.get_users_in_room(room_id) + prev_users = set(users) - newly_joined_users + + # Construct sets for all the local users and remote hosts that were + # already in the room + prev_local_users = [] + prev_remote_hosts = set() + for user_id in prev_users: + if self.is_mine_id(user_id): + prev_local_users.append(user_id) + else: + prev_remote_hosts.add(get_domain_from_id(user_id)) + + # Similarly, construct sets for all the local users and remote hosts + # that were *not* already in the room. Care needs to be taken when + # calculating the remote hosts, as a host may have already been in the + # room even if there is a newly joined user from that host. + newly_joined_local_users = [] + newly_joined_remote_hosts = set() + for user_id in newly_joined_users: + if self.is_mine_id(user_id): + newly_joined_local_users.append(user_id) + else: + host = get_domain_from_id(user_id) + if host not in prev_remote_hosts: + newly_joined_remote_hosts.add(host) - states_d = await self.current_state_for_users(user_ids) + # Send presence states of all local users in the room to newly joined + # remote servers. (We actually only send states for local users already + # in the room, as we'll send states for newly joined local users below.) + if prev_local_users and newly_joined_remote_hosts: + local_states = await self.current_state_for_users(prev_local_users) # Filter out old presence, i.e. offline presence states where # the user hasn't been active for a week. We can change this @@ -1306,13 +1302,27 @@ class PresenceHandler(BasePresenceHandler): now = self.clock.time_msec() states = [ state - for state in states_d.values() + for state in local_states.values() if state.state != PresenceState.OFFLINE or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000 or state.status_msg is not None ] - return [remote_host], states + self._federation_queue.send_presence_to_destinations( + destinations=newly_joined_remote_hosts, + states=states, + ) + + # Send presence states of newly joined users to all remote servers in + # the room + if newly_joined_local_users and ( + prev_remote_hosts or newly_joined_remote_hosts + ): + local_states = await self.current_state_for_users(newly_joined_local_users) + self._federation_queue.send_presence_to_destinations( + destinations=prev_remote_hosts | newly_joined_remote_hosts, + states=list(local_states.values()), + ) def should_notify(old_state: UserPresenceState, new_state: UserPresenceState) -> bool: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 88e9d0e2fe..0b805af60c 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1327,7 +1327,7 @@ class RoomShutdownHandler: new_room_id = None logger.info("Shutting down room %r", room_id) - users = await self.state.get_current_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) kicked_users = [] failed_to_kick_users = [] for user_id in users: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index a9a3ee05c3..0fcc1532da 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1190,7 +1190,7 @@ class SyncHandler: # Step 1b, check for newly joined rooms for room_id in newly_joined_rooms: - joined_users = await self.state.get_current_users_in_room(room_id) + joined_users = await self.store.get_users_in_room(room_id) newly_joined_or_invited_users.update(joined_users) # TODO: Check that these users are actually new, i.e.
either they @@ -1206,7 +1206,7 @@ class SyncHandler: # Now find users that we no longer track for room_id in newly_left_rooms: - left_users = await self.state.get_current_users_in_room(room_id) + left_users = await self.store.get_users_in_room(room_id) newly_left_users.update(left_users) # Remove any users that we still share a room with. @@ -1361,7 +1361,7 @@ class SyncHandler: extra_users_ids = set(newly_joined_or_invited_users) for room_id in newly_joined_rooms: - users = await self.state.get_current_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) extra_users_ids.update(users) extra_users_ids.discard(user.to_string()) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index bb837b7b19..6db1aece35 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -154,6 +154,7 @@ async def _handle_json_response( request: MatrixFederationRequest, response: IResponse, start_ms: int, + return_string_io=False, ) -> JsonDict: """ Reads the JSON body of a response, with a timeout @@ -175,12 +176,12 @@ async def _handle_json_response( d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE) d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) - def parse(_len: int): - return json_decoder.decode(buf.getvalue()) + await make_deferred_yieldable(d) - d.addCallback(parse) - - body = await make_deferred_yieldable(d) + if return_string_io: + body = buf + else: + body = json_decoder.decode(buf.getvalue()) except BodyExceededMaxSize as e: # The response was too big. logger.warning( @@ -225,12 +226,13 @@ async def _handle_json_response( time_taken_secs = reactor.seconds() - start_ms / 1000 logger.info( - "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s", + "{%s} [%s] Completed request: %d %s in %.2f secs got %dB - %s %s", request.txn_id, request.destination, response.code, response.phrase.decode("ascii", errors="replace"), time_taken_secs, + len(buf.getvalue()), request.method, request.uri.decode("ascii"), ) @@ -683,6 +685,7 @@ class MatrixFederationHttpClient: ignore_backoff: bool = False, backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, + return_string_io=False, ) -> Union[JsonDict, list]: """Sends the specified json data using PUT @@ -757,7 +760,12 @@ class MatrixFederationHttpClient: _sec_timeout = self.default_timeout body = await _handle_json_response( - self.reactor, _sec_timeout, request, response, start_ms + self.reactor, + _sec_timeout, + request, + response, + start_ms, + return_string_io=return_string_io, ) return body diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 31b7b3c256..fef2846669 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -535,6 +535,13 @@ class ReactorLastSeenMetric: REGISTRY.register(ReactorLastSeenMetric()) +# The minimum time in seconds between GCs for each generation, regardless of the current GC +# thresholds and counts. +MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0) + +# The time (in seconds since the epoch) of the last time we did a GC for each generation. +_last_gc = [0.0, 0.0, 0.0] + def runUntilCurrentTimer(reactor, func): @functools.wraps(func) @@ -575,11 +582,16 @@ def runUntilCurrentTimer(reactor, func): return ret # Check if we need to do a manual GC (since its been disabled), and do - # one if necessary. + # one if necessary. Note we go in reverse order as e.g. 
a gen 1 GC may + promote an object into gen 2, and we don't want to handle the same + object multiple times. threshold = gc.get_threshold() counts = gc.get_count() for i in (2, 1, 0): - if threshold[i] < counts[i]: + # We check if we need to do one based on a straightforward + # comparison between the threshold and count. We also do an extra + # check to make sure that we don't do a GC too often. + if threshold[i] < counts[i] and MIN_TIME_BETWEEN_GCS[i] < end - _last_gc[i]: if i == 0: logger.debug("Collecting gc %d", i) else: @@ -589,6 +601,8 @@ def runUntilCurrentTimer(reactor, func): unreachable = gc.collect(i) end = time.time() + _last_gc[i] = end + gc_time.labels(i).observe(end - start) gc_unreachable.labels(i).set(unreachable) @@ -615,6 +629,7 @@ try: except AttributeError: pass + __all__ = [ "MetricsResource", "generate_latest", diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py new file mode 100644 index 0000000000..07ed1d2453 --- /dev/null +++ b/synapse/metrics/jemalloc.py @@ -0,0 +1,191 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ctypes +import logging +import os +import re +from typing import Optional + +from synapse.metrics import REGISTRY, GaugeMetricFamily + +logger = logging.getLogger(__name__) + + +def _setup_jemalloc_stats(): + """Checks to see if jemalloc is loaded, and hooks up a collector to record + statistics exposed by jemalloc. + """ + + # Try to find the loaded jemalloc shared library, if any. We need to + # introspect into what is loaded, rather than loading whatever is on the + # path, as if we load a *different* jemalloc version things will seg fault. + + # We look in `/proc/self/maps`, which only exists on linux. + if not os.path.exists("/proc/self/maps"): + logger.debug("Not looking for jemalloc as no /proc/self/maps exist") + return + + # We're looking for a path at the end of the line that includes + # "libjemalloc". + regex = re.compile(r"/\S+/libjemalloc.*$") + + jemalloc_path = None + with open("/proc/self/maps") as f: + for line in f: + match = regex.search(line.strip()) + if match: + jemalloc_path = match.group() + + if not jemalloc_path: + # No loaded jemalloc was found. + logger.debug("jemalloc not found") + return + + jemalloc = ctypes.CDLL(jemalloc_path) + + def _mallctl( + name: str, read: bool = True, write: Optional[int] = None + ) -> Optional[int]: + """Wrapper around `mallctl` for reading and writing integers to + jemalloc. + + Args: + name: The name of the option to read from/write to. + read: Whether to try and read the value. + write: The value to write, if given. + + Returns: + The value read if `read` is True, otherwise None. + + Raises: + An exception if `mallctl` returns a non-zero error code.
+ """ + + input_var = None + input_var_ref = None + input_len_ref = None + if read: + input_var = ctypes.c_size_t(0) + input_len = ctypes.c_size_t(ctypes.sizeof(input_var)) + + input_var_ref = ctypes.byref(input_var) + input_len_ref = ctypes.byref(input_len) + + write_var_ref = None + write_len = ctypes.c_size_t(0) + if write is not None: + write_var = ctypes.c_size_t(write) + write_len = ctypes.c_size_t(ctypes.sizeof(write_var)) + + write_var_ref = ctypes.byref(write_var) + + # The interface is: + # + # int mallctl( + # const char *name, + # void *oldp, + # size_t *oldlenp, + # void *newp, + # size_t newlen + # ) + # + # Where oldp/oldlenp is a buffer where the old value will be written to + # (if not null), and newp/newlen is the buffer with the new value to set + # (if not null). Note that they're all references *except* newlen. + result = jemalloc.mallctl( + name.encode("ascii"), + input_var_ref, + input_len_ref, + write_var_ref, + write_len, + ) + + if result != 0: + raise Exception("Failed to call mallctl") + + if input_var is None: + return None + + return input_var.value + + def _jemalloc_refresh_stats() -> None: + """Request that jemalloc updates its internal statistics. This needs to + be called before querying for stats, otherwise it will return stale + values. + """ + try: + _mallctl("epoch", read=False, write=1) + except Exception: + pass + + class JemallocCollector: + """Metrics for internal jemalloc stats.""" + + def collect(self): + _jemalloc_refresh_stats() + + g = GaugeMetricFamily( + "jemalloc_stats_app_memory_bytes", + "The stats reported by jemalloc", + labels=["type"], + ) + + # Read the relevant global stats from jemalloc. Note that these may + # not be accurate if python is configured to use its internal small + # object allocator (which is on by default, disable by setting the + # env `PYTHONMALLOC=malloc`). + # + # See the jemalloc manpage for details about what each value means, + # roughly: + # - allocated ─ Total number of bytes allocated by the app + # - active ─ Total number of bytes in active pages allocated by + # the application, this is bigger than `allocated`. + # - resident ─ Maximum number of bytes in physically resident data + # pages mapped by the allocator, comprising all pages dedicated + # to allocator metadata, pages backing active allocations, and + # unused dirty pages. This is bigger than `active`. + # - mapped ─ Total number of bytes in active extents mapped by the + # allocator. + # - metadata ─ Total number of bytes dedicated to jemalloc + # metadata. + for t in ( + "allocated", + "active", + "resident", + "mapped", + "metadata", + ): + try: + value = _mallctl(f"stats.{t}") + except Exception: + # There was an error fetching the value, skip. + continue + + g.add_metric([t], value=value) + + yield g + + REGISTRY.register(JemallocCollector()) + + logger.debug("Added jemalloc stats") + + +def setup_jemalloc_stats(): + """Try to setup jemalloc stats, if jemalloc is loaded.""" + + try: + _setup_jemalloc_stats() + except Exception as e: + logger.info("Failed to setup collector to record jemalloc stats: %s", e) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 2de946f464..d58eeeaa74 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -116,6 +116,8 @@ CONDITIONAL_REQUIREMENTS = { # hiredis is not a *strict* dependency, but it makes things much faster. # (if it is not installed, we fall back to slow code.) 
"redis": ["txredisapi>=1.4.7", "hiredis"], + # Required to use experimental `caches.track_memory_usage` config option. + "cache_memory": ["pympler"], } ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index f648678b09..e19e9ef5c7 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -22,6 +22,7 @@ from synapse.crypto.keyring import ServerKeyFetcher from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.util import json_decoder +from synapse.util.async_helpers import yieldable_gather_results logger = logging.getLogger(__name__) @@ -213,7 +214,13 @@ class RemoteKey(DirectServeJsonResource): # If there is a cache miss, request the missing keys, then recurse (and # ensure the result is sent). if cache_misses and query_remote_on_cache_miss: - await self.fetcher.get_keys(cache_misses) + await yieldable_gather_results( + lambda t: self.fetcher.get_keys(*t), + ( + (server_name, list(keys), 0) + for server_name, keys in cache_misses.items() + ), + ) await self.query_keys(request, query, query_remote_on_cache_miss=False) else: signed_keys = [] diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index b3bd92d37c..a1770f620e 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -213,19 +213,23 @@ class StateHandler: return ret.state async def get_current_users_in_room( - self, room_id: str, latest_event_ids: Optional[List[str]] = None + self, room_id: str, latest_event_ids: List[str] ) -> Dict[str, ProfileInfo]: """ Get the users who are currently in a room. + Note: This is much slower than using the equivalent method + `DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`, + so this should only be used when wanting the users at a particular point + in the room. + Args: room_id: The ID of the room. latest_event_ids: Precomputed list of latest event IDs. Will be computed if None. Returns: Dictionary of user IDs to their profileinfo. """ - if not latest_event_ids: - latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) + assert latest_event_ids is not None logger.debug("calling resolve_state_groups from get_current_users_in_room") diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 6b68d8720c..3d98d3f5f8 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -69,6 +69,7 @@ class SQLBaseStore(metaclass=ABCMeta): self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) + self._attempt_to_invalidate_cache("get_users_in_room_with_profiles", (room_id,)) self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,)) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 2a8532f8c1..5fc3bb5a7d 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -205,8 +205,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]: sql = """ - SELECT user_id, display_name, avatar_url FROM room_memberships - WHERE room_id = ? AND membership = ? 
+ SELECT state_key, display_name, avatar_url FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? """ txn.execute(sql, (room_id, Membership.JOIN)) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 7a082fdd21..a6bfb4902a 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -142,8 +142,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): batch_size (int): Maximum number of state events to process per cycle. """ - state = self.hs.get_state_handler() - # If we don't have progress filed, delete everything. if not progress: await self.delete_all_from_user_dir() @@ -197,7 +195,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): room_id ) - users_with_profile = await state.get_current_users_in_room(room_id) + users_with_profile = await self.get_users_in_room_with_profiles(room_id) user_ids = set(users_with_profile) # Update each user in the user directory. diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 46af7fa473..ca36f07c20 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -24,6 +24,11 @@ from synapse.config.cache import add_resizable_cache logger = logging.getLogger(__name__) + +# Whether to track estimated memory usage of the LruCaches. +TRACK_MEMORY_USAGE = False + + caches_by_name = {} # type: Dict[str, Sized] collectors_by_name = {} # type: Dict[str, CacheMetric] @@ -32,6 +37,11 @@ cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"]) cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"]) cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"]) +cache_memory_usage = Gauge( + "synapse_util_caches_cache_size_bytes", + "Estimated memory usage of the caches", + ["name"], +) response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"]) response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"]) @@ -52,6 +62,7 @@ class CacheMetric: hits = attr.ib(default=0) misses = attr.ib(default=0) evicted_size = attr.ib(default=0) + memory_usage = attr.ib(default=None) def inc_hits(self): self.hits += 1 @@ -62,6 +73,19 @@ class CacheMetric: def inc_evictions(self, size=1): self.evicted_size += size + def inc_memory_usage(self, memory: int): + if self.memory_usage is None: + self.memory_usage = 0 + + self.memory_usage += memory + + def dec_memory_usage(self, memory: int): + self.memory_usage -= memory + + def clear_memory_usage(self): + if self.memory_usage is not None: + self.memory_usage = 0 + def describe(self): return [] @@ -81,6 +105,13 @@ class CacheMetric: cache_total.labels(self._cache_name).set(self.hits + self.misses) if getattr(self._cache, "max_size", None): cache_max_size.labels(self._cache_name).set(self._cache.max_size) + + if TRACK_MEMORY_USAGE: + # self.memory_usage can be None if nothing has been inserted + # into the cache yet. 
+ cache_memory_usage.labels(self._cache_name).set( + self.memory_usage or 0 + ) if self._collect_callback: self._collect_callback() except Exception as e: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 10b0ec6b75..1be675e014 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -32,9 +32,36 @@ from typing import ( from typing_extensions import Literal from synapse.config import cache as cache_config +from synapse.util import caches from synapse.util.caches import CacheMetric, register_cache from synapse.util.caches.treecache import TreeCache +try: + from pympler.asizeof import Asizer + + def _get_size_of(val: Any, *, recurse=True) -> int: + """Get an estimate of the size in bytes of the object. + + Args: + val: The object to size. + recurse: If true will include referenced values in the size, + otherwise only sizes the given object. + """ + # Ignore singleton values when calculating memory usage. + if val in ((), None, ""): + return 0 + + sizer = Asizer() + sizer.exclude_refs((), None, "") + return sizer.asizeof(val, limit=100 if recurse else 0) + + +except ImportError: + + def _get_size_of(val: Any, *, recurse=True) -> int: + return 0 + + # Function type: the type used for invalidation callbacks FT = TypeVar("FT", bound=Callable[..., Any]) @@ -56,7 +83,7 @@ def enumerate_leaves(node, depth): class _Node: - __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] + __slots__ = ["prev_node", "next_node", "key", "value", "callbacks", "memory"] def __init__( self, @@ -84,6 +111,16 @@ class _Node: self.add_callbacks(callbacks) + self.memory = 0 + if caches.TRACK_MEMORY_USAGE: + self.memory = ( + _get_size_of(key) + + _get_size_of(value) + + _get_size_of(self.callbacks, recurse=False) + + _get_size_of(self, recurse=False) + ) + self.memory += _get_size_of(self.memory, recurse=False) + def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None: """Add to stored list of callbacks, removing duplicates.""" @@ -233,6 +270,9 @@ class LruCache(Generic[KT, VT]): if size_callback: cached_cache_len[0] += size_callback(node.value) + if caches.TRACK_MEMORY_USAGE and metrics: + metrics.inc_memory_usage(node.memory) + def move_node_to_front(node): prev_node = node.prev_node next_node = node.next_node @@ -258,6 +298,9 @@ class LruCache(Generic[KT, VT]): node.run_and_clear_callbacks() + if caches.TRACK_MEMORY_USAGE and metrics: + metrics.dec_memory_usage(node.memory) + return deleted_len @overload @@ -373,6 +416,9 @@ class LruCache(Generic[KT, VT]): if size_callback: cached_cache_len[0] = 0 + if caches.TRACK_MEMORY_USAGE and metrics: + metrics.clear_memory_usage() + @synchronized def cache_contains(key: KT) -> bool: return key in cache diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index ce330e79cc..1ffab709fc 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -729,7 +729,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_state.state, PresenceState.ONLINE) self.federation_sender.send_presence_to_destinations.assert_called_once_with( - destinations=["server2"], states={expected_state} + destinations={"server2"}, states=[expected_state] ) # @@ -740,7 +740,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): self._add_new_user(room_id, "@bob:server3") self.federation_sender.send_presence_to_destinations.assert_called_once_with( - destinations=["server3"], states={expected_state} + 
destinations={"server3"}, states=[expected_state] ) def test_remote_gets_presence_when_local_user_joins(self): @@ -788,14 +788,8 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): self.presence_handler.current_state_for_user("@test2:server") ) self.assertEqual(expected_state.state, PresenceState.ONLINE) - self.assertEqual( - self.federation_sender.send_presence_to_destinations.call_count, 2 - ) - self.federation_sender.send_presence_to_destinations.assert_any_call( - destinations=["server3"], states={expected_state} - ) - self.federation_sender.send_presence_to_destinations.assert_any_call( - destinations=["server2"], states={expected_state} + self.federation_sender.send_presence_to_destinations.assert_called_once_with( + destinations={"server2", "server3"}, states=[expected_state] ) def _add_new_user(self, room_id, user_id):
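
For reference, the rate-limited manual GC added to synapse/metrics/__init__.py above can be summarised with the following minimal, standalone sketch. The helper name maybe_collect and the module-level state are illustrative only; in Synapse the equivalent logic runs inside runUntilCurrentTimer, which wraps the reactor loop.

import gc
import time

# Per-generation minimum interval between collections, in seconds
# (mirrors MIN_TIME_BETWEEN_GCS above).
MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)

# Epoch timestamp of the last manual collection of each generation.
_last_gc = [0.0, 0.0, 0.0]


def maybe_collect() -> None:
    """Collect any generation whose count exceeds its threshold, but never
    more often than the minimum interval configured for that generation."""
    threshold = gc.get_threshold()
    counts = gc.get_count()
    now = time.time()

    # Highest generation first: a gen 1 collection may promote objects into
    # gen 2, and we don't want to examine the same objects twice.
    for i in (2, 1, 0):
        if threshold[i] < counts[i] and now - _last_gc[i] > MIN_TIME_BETWEEN_GCS[i]:
            gc.collect(i)
            _last_gc[i] = time.time()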
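
Likewise, a rough sketch of how the optional pympler dependency (the new cache_memory extra) is used to estimate the memory footprint of cache entries. The function name estimate_size and the example values are illustrative; the real helper is _get_size_of in synapse/util/caches/lrucache.py and is only consulted when caches.track_memory_usage is enabled.

from typing import Any

from pympler.asizeof import Asizer


def estimate_size(val: Any, recurse: bool = True) -> int:
    """Estimate the size of `val` in bytes."""
    # Shared singletons contribute nothing to a cache entry's footprint.
    if val in ((), None, ""):
        return 0

    sizer = Asizer()
    sizer.exclude_refs((), None, "")
    # limit=0 sizes just the object itself; a positive limit also follows
    # references, capped here at 100 levels as in the patch.
    return sizer.asizeof(val, limit=100 if recurse else 0)


# Example: rough footprint of a key/value pair as stored in an LruCache node.
key = ("!room:example.org",)
value = {"@alice:example.org": "online"}
print(estimate_size(key) + estimate_size(value))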