diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 638e01c1b2..59918d789e 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -37,6 +37,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.metrics.jemalloc import setup_jemalloc_stats
from synapse.util.async_helpers import Linearizer
from synapse.util.daemonize import daemonize_process
from synapse.util.rlimit import change_resource_limit
@@ -115,6 +116,7 @@ def start_reactor(
def run():
logger.info("Running")
+ setup_jemalloc_stats()
change_resource_limit(soft_file_limit)
if gc_thresholds:
gc.set_threshold(*gc_thresholds)
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 1a15ceee81..f730cdbd78 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -454,6 +454,10 @@ def start(config_options):
config.server.update_user_directory = False
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
+ synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage
+
+ if config.server.gc_seconds:
+ synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
hs = GenericWorkerServer(
config.server_name,
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 8e78134bbe..b2501ee4d7 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -341,6 +341,10 @@ def setup(config_options):
sys.exit(0)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
+ synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage
+
+ if config.server.gc_seconds:
+ synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
hs = SynapseHomeServer(
config.server_name,
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 41b9b3f51f..91165ee1ce 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -17,6 +17,8 @@ import re
import threading
from typing import Callable, Dict
+from synapse.python_dependencies import DependencyException, check_requirements
+
from ._base import Config, ConfigError
# The prefix for all cache factor-related environment variables
@@ -189,6 +191,15 @@ class CacheConfig(Config):
)
self.cache_factors[cache] = factor
+ self.track_memory_usage = cache_config.get("track_memory_usage", False)
+ if self.track_memory_usage:
+ try:
+ check_requirements("cache_memory")
+ except DependencyException as e:
+ raise ConfigError(
+ e.message # noqa: B306, DependencyException.message is a property
+ )
+
# Resize all caches (if necessary) with the new factors we've loaded
self.resize_all_caches()
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 21ca7b33e3..c290a35a92 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -19,7 +19,7 @@ import logging
import os.path
import re
from textwrap import indent
-from typing import Any, Dict, Iterable, List, Optional, Set
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
import attr
import yaml
@@ -572,6 +572,7 @@ class ServerConfig(Config):
_warn_if_webclient_configured(self.listeners)
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
+ self.gc_seconds = self.read_gc_intervals(config.get("gc_min_interval", None))
@attr.s
class LimitRemoteRoomsConfig:
@@ -917,6 +918,16 @@ class ServerConfig(Config):
#
#gc_thresholds: [700, 10, 10]
+ # The minimum time in seconds between each GC for a generation, regardless of
+ # the GC thresholds. This ensures that we don't do GC too frequently.
+ #
+ # A value of `[1s, 10s, 30s]` indicates that a second must pass between consecutive
+ # generation 0 GCs, etc.
+ #
+ # Defaults to `[1s, 10s, 30s]`.
+ #
+ #gc_min_interval: [0.5s, 30s, 1m]
+
# Set the limit on the returned events in the timeline in the get
# and sync operations. The default value is 100. -1 means no upper limit.
#
@@ -1305,6 +1316,24 @@ class ServerConfig(Config):
help="Turn on the twisted telnet manhole service on the given port.",
)
+ def read_gc_intervals(self, durations) -> Optional[Tuple[float, float, float]]:
+ """Reads the three durations for the GC min interval option, returning seconds."""
+ if durations is None:
+ return None
+
+ try:
+ if len(durations) != 3:
+ raise ValueError()
+ return (
+ self.parse_duration(durations[0]) / 1000,
+ self.parse_duration(durations[1]) / 1000,
+ self.parse_duration(durations[2]) / 1000,
+ )
+ except Exception:
+ raise ConfigError(
+ "Value of `gc_min_interval` must be a list of three durations if set"
+ )
+
def is_threepid_reserved(reserved_threepids, threepid):
"""Check the threepid against the reserved threepid config
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 5f18ef7748..9efe13606a 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -16,8 +16,7 @@
import abc
import logging
import urllib
-from collections import defaultdict
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
import attr
from signedjson.key import (
@@ -42,17 +41,18 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.config.key import TrustedKeyServer
+from synapse.events import EventBase
+from synapse.events.utils import prune_event_dict
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
- preserve_fn,
run_in_background,
)
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict
from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import yieldable_gather_results
-from synapse.util.metrics import Measure
+from synapse.util.async_helpers import Linearizer, yieldable_gather_results
from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
@@ -74,8 +74,6 @@ class VerifyJsonRequest:
minimum_valid_until_ts: time at which we require the signing key to
be valid. (0 implies we don't care)
- request_name: The name of the request.
-
key_ids: The set of key_ids to that could be used to verify the JSON object
key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
@@ -88,20 +86,93 @@ class VerifyJsonRequest:
"""
server_name = attr.ib(type=str)
- json_object = attr.ib(type=JsonDict)
+ json_object_callback = attr.ib(type=Callable[[], JsonDict])
minimum_valid_until_ts = attr.ib(type=int)
- request_name = attr.ib(type=str)
- key_ids = attr.ib(init=False, type=List[str])
- key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
+ key_ids = attr.ib(type=List[str])
- def __attrs_post_init__(self):
- self.key_ids = signature_ids(self.json_object, self.server_name)
+ @staticmethod
+ def from_json_object(
+ server_name: str, minimum_valid_until_ms: int, json_object: JsonDict
+ ):
+ key_ids = signature_ids(json_object, server_name)
+ return VerifyJsonRequest(
+ server_name, lambda: json_object, minimum_valid_until_ms, key_ids
+ )
+
+ @staticmethod
+ def from_event(
+ server_name: str,
+ minimum_valid_until_ms: int,
+ event: EventBase,
+ ):
+ key_ids = list(event.signatures.get(server_name, []))
+ return VerifyJsonRequest(
+ server_name,
+ lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
+ minimum_valid_until_ms,
+ key_ids,
+ )
class KeyLookupError(ValueError):
pass
+@attr.s(slots=True)
+class _QueueValue:
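+    """A bundled key request: the server to fetch for, the key IDs wanted and
+    the minimum validity time. (Sketch of a docstring for this new class.)
+    """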
+ server_name = attr.ib(type=str)
+ minimum_valid_until_ts = attr.ib(type=int)
+ key_ids = attr.ib(type=List[str])
+
+
+class _Queue:
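+    """Batches up pending requests and passes them to `process_items` in a
+    single background call, resolving each caller's deferred with the result
+    for its server. (Sketch of a docstring for this new class.)
+    """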
+ def __init__(self, name, clock, process_items):
+ self._name = name
+ self._clock = clock
+ self._is_processing = False
+ self._next_values = []
+
+ self.process_items = process_items
+
+ async def add_to_queue(self, value: _QueueValue) -> Dict[str, FetchKeyResult]:
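+        """Adds the request to the queue, kicking off the background
+        processing loop if it isn't already running, and waits for the keys
+        found for this request's server.
+        """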
+ d = defer.Deferred()
+ self._next_values.append((value, d))
+
+ if not self._is_processing:
+ run_as_background_process(self._name, self._unsafe_process)
+
+ return await make_deferred_yieldable(d)
+
+ async def _unsafe_process(self):
+        # Bail out if another call is already draining the queue. This check
+        # must happen *before* the try/finally below, otherwise an early
+        # return would clear the flag owned by the in-flight call.
+        if self._is_processing:
+            return
+
+        self._is_processing = True
+
+        try:
+
+ while self._next_values:
+ # We purposefully defer to the next loop.
+ await self._clock.sleep(0)
+
+ next_values = self._next_values
+ self._next_values = []
+
+ try:
+ values = [value for value, _ in next_values]
+ results = await self.process_items(values)
+
+ for value, deferred in next_values:
+ with PreserveLoggingContext():
+ deferred.callback(results.get(value.server_name, {}))
+
+ except Exception as e:
+ for _, deferred in next_values:
+ deferred.errback(e)
+
+ finally:
+ self._is_processing = False
+
+
class Keyring:
def __init__(
self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
@@ -116,12 +187,7 @@ class Keyring:
)
self._key_fetchers = key_fetchers
- # map from server name to Deferred. Has an entry for each server with
- # an ongoing key download; the Deferred completes once the download
- # completes.
- #
- # These are regular, logcontext-agnostic Deferreds.
- self.key_downloads = {} # type: Dict[str, defer.Deferred]
+ self._server_queue = Linearizer("keyring_server")
def verify_json_for_server(
self,
@@ -130,365 +196,150 @@ class Keyring:
validity_time: int,
request_name: str,
) -> defer.Deferred:
- """Verify that a JSON object has been signed by a given server
-
- Args:
- server_name: name of the server which must have signed this object
-
- json_object: object to be checked
-
- validity_time: timestamp at which we require the signing key to
- be valid. (0 implies we don't care)
-
- request_name: an identifier for this json object (eg, an event id)
- for logging.
-
- Returns:
- Deferred[None]: completes if the the object was correctly signed, otherwise
- errbacks with an error
- """
- req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
- requests = (req,)
- return make_deferred_yieldable(self._verify_objects(requests)[0])
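+        """Verify that a JSON object has been signed by a given server.
+
+        Returns a deferred which completes if the object was correctly signed,
+        otherwise errbacks with an error.
+        """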
+ request = VerifyJsonRequest.from_json_object(
+ server_name,
+ validity_time,
+ json_object,
+ )
+ return defer.ensureDeferred(self._verify_object(request))
def verify_json_objects_for_server(
self, server_and_json: Iterable[Tuple[str, dict, int, str]]
) -> List[defer.Deferred]:
- """Bulk verifies signatures of json objects, bulk fetching keys as
- necessary.
-
- Args:
- server_and_json:
- Iterable of (server_name, json_object, validity_time, request_name)
- tuples.
-
- validity_time is a timestamp at which the signing key must be
- valid.
-
- request_name is an identifier for this json object (eg, an event id)
- for logging.
-
- Returns:
- List<Deferred[None]>: for each input triplet, a deferred indicating success
- or failure to verify each json object's signature for the given
- server_name. The deferreds run their callbacks in the sentinel
- logcontext.
- """
- return self._verify_objects(
- VerifyJsonRequest(server_name, json_object, validity_time, request_name)
- for server_name, json_object, validity_time, request_name in server_and_json
- )
-
- def _verify_objects(
- self, verify_requests: Iterable[VerifyJsonRequest]
- ) -> List[defer.Deferred]:
- """Does the work of verify_json_[objects_]for_server
-
-
- Args:
- verify_requests: Iterable of verification requests.
-
- Returns:
- List<Deferred[None]>: for each input item, a deferred indicating success
- or failure to verify each json object's signature for the given
- server_name. The deferreds run their callbacks in the sentinel
- logcontext.
- """
- # a list of VerifyJsonRequests which are awaiting a key lookup
- key_lookups = []
- handle = preserve_fn(_handle_key_deferred)
-
- def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
- """Process an entry in the request list
-
- Adds a key request to key_lookups, and returns a deferred which
- will complete or fail (in the sentinel context) when verification completes.
- """
- if not verify_request.key_ids:
- return defer.fail(
- SynapseError(
- 400,
- "Not signed by %s" % (verify_request.server_name,),
- Codes.UNAUTHORIZED,
- )
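+        """Bulk verifies signatures of json objects, fetching keys as
+        necessary.
+        """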
+ return [
+ defer.ensureDeferred(
+ run_in_background(
+ self._verify_object,
+ VerifyJsonRequest.from_json_object(
+ server_name,
+ validity_time,
+ json_object,
+ ),
)
-
- logger.debug(
- "Verifying %s for %s with key_ids %s, min_validity %i",
- verify_request.request_name,
- verify_request.server_name,
- verify_request.key_ids,
- verify_request.minimum_valid_until_ts,
)
+ for server_name, json_object, validity_time, request_name in server_and_json
+ ]
- # add the key request to the queue, but don't start it off yet.
- key_lookups.append(verify_request)
-
- # now run _handle_key_deferred, which will wait for the key request
- # to complete and then do the verification.
- #
- # We want _handle_key_request to log to the right context, so we
- # wrap it with preserve_fn (aka run_in_background)
- return handle(verify_request)
-
- results = [process(r) for r in verify_requests]
-
- if key_lookups:
- run_in_background(self._start_key_lookups, key_lookups)
-
- return results
-
- async def _start_key_lookups(
- self, verify_requests: List[VerifyJsonRequest]
- ) -> None:
- """Sets off the key fetches for each verify request
-
- Once each fetch completes, verify_request.key_ready will be resolved.
-
- Args:
- verify_requests:
- """
-
- try:
- # map from server name to a set of outstanding request ids
- server_to_request_ids = {} # type: Dict[str, Set[int]]
-
- for verify_request in verify_requests:
- server_name = verify_request.server_name
- request_id = id(verify_request)
- server_to_request_ids.setdefault(server_name, set()).add(request_id)
-
- # Wait for any previous lookups to complete before proceeding.
- await self.wait_for_previous_lookups(server_to_request_ids.keys())
-
- # take out a lock on each of the servers by sticking a Deferred in
- # key_downloads
- for server_name in server_to_request_ids.keys():
- self.key_downloads[server_name] = defer.Deferred()
- logger.debug("Got key lookup lock on %s", server_name)
-
- # When we've finished fetching all the keys for a given server_name,
- # drop the lock by resolving the deferred in key_downloads.
- def drop_server_lock(server_name):
- d = self.key_downloads.pop(server_name)
- d.callback(None)
-
- def lookup_done(res, verify_request):
- server_name = verify_request.server_name
- server_requests = server_to_request_ids[server_name]
- server_requests.remove(id(verify_request))
-
- # if there are no more requests for this server, we can drop the lock.
- if not server_requests:
- logger.debug("Releasing key lookup lock on %s", server_name)
- drop_server_lock(server_name)
-
- return res
-
- for verify_request in verify_requests:
- verify_request.key_ready.addBoth(lookup_done, verify_request)
-
- # Actually start fetching keys.
- self._get_server_verify_keys(verify_requests)
- except Exception:
- logger.exception("Error starting key lookups")
-
- async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
- """Waits for any previous key lookups for the given servers to finish.
-
- Args:
- server_names: list of servers which we want to look up
-
- Returns:
- Resolves once all key lookups for the given servers have
- completed. Follows the synapse rules of logcontext preservation.
- """
- loop_count = 1
- while True:
- wait_on = [
- (server_name, self.key_downloads[server_name])
- for server_name in server_names
- if server_name in self.key_downloads
- ]
- if not wait_on:
- break
- logger.info(
- "Waiting for existing lookups for %s to complete [loop %i]",
- [w[0] for w in wait_on],
- loop_count,
+ def verify_events_for_server(
+ self, server_and_json: Iterable[Tuple[str, EventBase, int]]
+ ) -> List[defer.Deferred]:
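+        """Bulk verification of event signatures, checking each event's
+        signatures against its redacted (pruned) form.
+        """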
+ return [
+ run_in_background(
+ self._verify_object,
+                VerifyJsonRequest.from_event(
+ server_name,
+ validity_time,
+ event,
+ ),
)
- with PreserveLoggingContext():
- await defer.DeferredList((w[1] for w in wait_on))
-
- loop_count += 1
-
- def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
- """Tries to find at least one key for each verify request
-
- For each verify_request, verify_request.key_ready is called back with
- params (server_name, key_id, VerifyKey) if a key is found, or errbacked
- with a SynapseError if none of the keys are found.
+ for server_name, event, validity_time in server_and_json
+ ]
+
+ async def _verify_object(self, verify_request: VerifyJsonRequest):
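+        """Fetches the keys for the request, trying each key fetcher in turn,
+        and then checks the signatures on the JSON object.
+
+        Raises:
+            SynapseError if a key could not be found or a signature check
+            failed.
+        """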
+        # TODO: Batch up requests for the same server rather than processing
+        # them one at a time under the linearizer.
+ with (await self._server_queue.queue(verify_request.server_name)):
+ found_keys: Dict[str, FetchKeyResult] = {}
+ missing_key_ids = set(verify_request.key_ids)
+ for fetcher in self._key_fetchers:
+ if not missing_key_ids:
+ break
+
+ keys = await fetcher.get_keys(
+ verify_request.server_name,
+ list(missing_key_ids),
+ verify_request.minimum_valid_until_ts,
+ )
- Args:
- verify_requests: list of verify requests
- """
+ for key_id, key in keys.items():
+ if not key:
+ continue
- remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
+ if key.valid_until_ts < verify_request.minimum_valid_until_ts:
+ continue
- async def do_iterations():
- try:
- with Measure(self.clock, "get_server_verify_keys"):
- for f in self._key_fetchers:
- if not remaining_requests:
- return
- await self._attempt_key_fetches_with_fetcher(
- f, remaining_requests
- )
-
- # look for any requests which weren't satisfied
- while remaining_requests:
- verify_request = remaining_requests.pop()
- rq_str = (
- "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
- % (
- verify_request.server_name,
- verify_request.key_ids,
- verify_request.minimum_valid_until_ts,
- )
- )
-
- # If we run the errback immediately, it may cancel our
- # loggingcontext while we are still in it, so instead we
- # schedule it for the next time round the reactor.
- #
- # (this also ensures that we don't get a stack overflow if we
- # has a massive queue of lookups waiting for this server).
- self.clock.call_later(
- 0,
- verify_request.key_ready.errback,
- SynapseError(
- 401,
- "Failed to find any key to satisfy %s" % (rq_str,),
- Codes.UNAUTHORIZED,
- ),
- )
- except Exception as err:
- # we don't really expect to get here, because any errors should already
- # have been caught and logged. But if we do, let's log the error and make
- # sure that all of the deferreds are resolved.
- logger.error("Unexpected error in _get_server_verify_keys: %s", err)
- with PreserveLoggingContext():
- for verify_request in remaining_requests:
- if not verify_request.key_ready.called:
- verify_request.key_ready.errback(err)
-
- run_in_background(do_iterations)
-
- async def _attempt_key_fetches_with_fetcher(
- self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
- ):
- """Use a key fetcher to attempt to satisfy some key requests
+                    existing_key = found_keys.get(key_id)
+                    if existing_key:
+                        if key.valid_until_ts <= existing_key.valid_until_ts:
+                            continue
-        Args:
-            fetcher: fetcher to use to fetch the keys
-            remaining_requests: outstanding key requests.
-                Any successfully-completed requests will be removed from the list.
-        """
-        # The keys to fetch.
-        # server_name -> key_id -> min_valid_ts
-        missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]]
+                    found_keys[key_id] = key
-        for verify_request in remaining_requests:
-            # any completed requests should already have been removed
-            assert not verify_request.key_ready.called
-            keys_for_server = missing_keys[verify_request.server_name]
+                missing_key_ids.difference_update(found_keys)
- for key_id in verify_request.key_ids:
- # If we have several requests for the same key, then we only need to
- # request that key once, but we should do so with the greatest
- # min_valid_until_ts of the requests, so that we can satisfy all of
- # the requests.
- keys_for_server[key_id] = max(
- keys_for_server.get(key_id, -1),
- verify_request.minimum_valid_until_ts,
+ if missing_key_ids:
+ raise SynapseError(
+ 400,
+ "Missing keys for %s: %s"
+ % (verify_request.server_name, missing_key_ids),
+ Codes.UNAUTHORIZED,
)
- results = await fetcher.get_keys(missing_keys)
-
- completed = []
- for verify_request in remaining_requests:
- server_name = verify_request.server_name
-
- # see if any of the keys we got this time are sufficient to
- # complete this VerifyJsonRequest.
- result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids:
- fetch_key_result = result_keys.get(key_id)
- if not fetch_key_result:
- # we didn't get a result for this key
- continue
-
- if (
- fetch_key_result.valid_until_ts
- < verify_request.minimum_valid_until_ts
- ):
- # key was not valid at this point
- continue
-
- # we have a valid key for this request. If we run the callback
- # immediately, it may cancel our loggingcontext while we are still in
- # it, so instead we schedule it for the next time round the reactor.
- #
- # (this also ensures that we don't get a stack overflow if we had
- # a massive queue of lookups waiting for this server).
- logger.debug(
- "Found key %s:%s for %s",
- server_name,
- key_id,
- verify_request.request_name,
- )
- self.clock.call_later(
- 0,
- verify_request.key_ready.callback,
- (server_name, key_id, fetch_key_result.verify_key),
- )
- completed.append(verify_request)
- break
-
- remaining_requests.difference_update(completed)
+ verify_key = found_keys[key_id].verify_key
+ try:
+ json_object = verify_request.json_object_callback()
+ verify_signed_json(
+ json_object,
+ verify_request.server_name,
+ verify_key,
+ )
+ except SignatureVerifyException as e:
+ logger.debug(
+ "Error verifying signature for %s:%s:%s with key %s: %s",
+ verify_request.server_name,
+ verify_key.alg,
+ verify_key.version,
+ encode_verify_key_base64(verify_key),
+ str(e),
+ )
+ raise SynapseError(
+ 401,
+ "Invalid signature for server %s with key %s:%s: %s"
+ % (
+ verify_request.server_name,
+ verify_key.alg,
+ verify_key.version,
+ str(e),
+ ),
+ Codes.UNAUTHORIZED,
+ )
class KeyFetcher(metaclass=abc.ABCMeta):
- @abc.abstractmethod
+ def __init__(self, hs: "HomeServer"):
+ self._queue = _Queue(self.__class__.__name__, hs.get_clock(), self._fetch_keys)
+
async def get_keys(
- self, keys_to_fetch: Dict[str, Dict[str, int]]
- ) -> Dict[str, Dict[str, FetchKeyResult]]:
- """
- Args:
- keys_to_fetch:
- the keys to be fetched. server_name -> key_id -> min_valid_ts
+ self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+ ) -> Dict[str, FetchKeyResult]:
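+        """Fetches the keys for the given server, batching the lookup with any
+        other outstanding requests to this fetcher.
+        """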
+ return await self._queue.add_to_queue(
+ _QueueValue(
+ server_name=server_name,
+ key_ids=key_ids,
+ minimum_valid_until_ts=minimum_valid_until_ts,
+ )
+ )
- Returns:
- Map from server_name -> key_id -> FetchKeyResult
- """
- raise NotImplementedError
+ @abc.abstractmethod
+ async def _fetch_keys(
+ self, keys_to_fetch: List[_QueueValue]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
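+        """Fetches the requested keys, returning a map of
+        `server_name -> key_id -> FetchKeyResult`.
+        """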
+ pass
class StoreKeyFetcher(KeyFetcher):
"""KeyFetcher impl which fetches keys from our data store"""
def __init__(self, hs: "HomeServer"):
- self.store = hs.get_datastore()
+ super().__init__(hs)
- async def get_keys(
- self, keys_to_fetch: Dict[str, Dict[str, int]]
- ) -> Dict[str, Dict[str, FetchKeyResult]]:
- """see KeyFetcher.get_keys"""
+ self.store = hs.get_datastore()
+ async def _fetch_keys(self, keys_to_fetch: List[_QueueValue]):
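+        """see KeyFetcher._fetch_keys"""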
key_ids_to_fetch = (
- (server_name, key_id)
- for server_name, keys_for_server in keys_to_fetch.items()
- for key_id in keys_for_server.keys()
+ (queue_value.server_name, key_id)
+ for queue_value in keys_to_fetch
+ for key_id in queue_value.key_ids
)
res = await self.store.get_server_verify_keys(key_ids_to_fetch)
@@ -500,6 +351,8 @@ class StoreKeyFetcher(KeyFetcher):
class BaseV2KeyFetcher(KeyFetcher):
def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
self.store = hs.get_datastore()
self.config = hs.config
@@ -607,10 +460,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
self.client = hs.get_federation_http_client()
self.key_servers = self.config.key_servers
- async def get_keys(
- self, keys_to_fetch: Dict[str, Dict[str, int]]
+ async def _fetch_keys(
+ self, keys_to_fetch: List[_QueueValue]
) -> Dict[str, Dict[str, FetchKeyResult]]:
- """see KeyFetcher.get_keys"""
+ """see KeyFetcher._fetch_keys"""
async def get_key(key_server: TrustedKeyServer) -> Dict:
try:
@@ -646,12 +499,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return union_of_keys
async def get_server_verify_key_v2_indirect(
- self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
+ self, keys_to_fetch: List[_QueueValue], key_server: TrustedKeyServer
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
keys_to_fetch:
- the keys to be fetched. server_name -> key_id -> min_valid_ts
+ the keys to be fetched.
key_server: notary server to query for the keys
@@ -665,7 +518,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
perspective_name = key_server.server_name
logger.info(
"Requesting keys %s from notary server %s",
- keys_to_fetch.items(),
+ keys_to_fetch,
perspective_name,
)
@@ -675,11 +528,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
path="/_matrix/key/v2/query",
data={
"server_keys": {
- server_name: {
- key_id: {"minimum_valid_until_ts": min_valid_ts}
- for key_id, min_valid_ts in server_keys.items()
+ queue_value.server_name: {
+ key_id: {
+ "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
+ }
+ for key_id in queue_value.key_ids
}
- for server_name, server_keys in keys_to_fetch.items()
+ for queue_value in keys_to_fetch
}
},
)
@@ -779,8 +634,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.clock = hs.get_clock()
self.client = hs.get_federation_http_client()
- async def get_keys(
- self, keys_to_fetch: Dict[str, Dict[str, int]]
+ async def _fetch_keys(
+ self, keys_to_fetch: List[_QueueValue]
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
@@ -793,8 +648,10 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
results = {}
- async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
- server_name, key_ids = key_to_fetch_item
+ async def get_key(key_to_fetch_item: _QueueValue) -> None:
+ server_name = key_to_fetch_item.server_name
+ key_ids = key_to_fetch_item.key_ids
+
try:
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
results[server_name] = keys
@@ -805,7 +662,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except Exception:
logger.exception("Error getting keys %s from %s", key_ids, server_name)
- await yieldable_gather_results(get_key, keys_to_fetch.items())
+ await yieldable_gather_results(get_key, keys_to_fetch)
return results
async def get_server_verify_key_v2_direct(
@@ -877,37 +734,3 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
keys.update(response_keys)
return keys
-
-
-async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
- """Waits for the key to become available, and then performs a verification
-
- Args:
- verify_request:
-
- Raises:
- SynapseError if there was a problem performing the verification
- """
- server_name = verify_request.server_name
- with PreserveLoggingContext():
- _, key_id, verify_key = await verify_request.key_ready
-
- json_object = verify_request.json_object
-
- try:
- verify_signed_json(json_object, server_name, verify_key)
- except SignatureVerifyException as e:
- logger.debug(
- "Error verifying signature for %s:%s:%s with key %s: %s",
- server_name,
- verify_key.alg,
- verify_key.version,
- encode_verify_key_base64(verify_key),
- str(e),
- )
- raise SynapseError(
- 401,
- "Invalid signature for server %s with key %s:%s: %s"
- % (server_name, verify_key.alg, verify_key.version, str(e)),
- Codes.UNAUTHORIZED,
- )
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index a6471ce6cc..dd631c7794 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -73,10 +73,10 @@ class FederationBase:
* throws a SynapseError if the signature check failed.
The deferreds run their callbacks in the sentinel
"""
- deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
-
ctx = current_context()
+ deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
+
@defer.inlineCallbacks
def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx):
@@ -135,11 +135,7 @@ class FederationBase:
return deferreds
-class PduToCheckSig(
- namedtuple(
- "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
- )
-):
+class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
pass
@@ -182,7 +178,6 @@ def _check_sigs_on_pdus(
pdus_to_check = [
PduToCheckSig(
pdu=p,
- redacted_pdu_json=prune_event(p).get_pdu_json(),
sender_domain=get_domain_from_id(p.sender),
deferreds=[],
)
@@ -193,13 +188,12 @@ def _check_sigs_on_pdus(
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
- more_deferreds = keyring.verify_json_objects_for_server(
+ more_deferreds = keyring.verify_events_for_server(
[
(
p.sender_domain,
- p.redacted_pdu_json,
+ p.pdu,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
- p.pdu.event_id,
)
for p in pdus_to_check_sender
]
@@ -228,13 +222,12 @@ def _check_sigs_on_pdus(
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
]
- more_deferreds = keyring.verify_json_objects_for_server(
+ more_deferreds = keyring.verify_events_for_server(
[
(
get_domain_from_id(p.pdu.event_id),
- p.redacted_pdu_json,
+ p.pdu,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
- p.pdu.event_id,
)
for p in pdus_to_check_event_id
]
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index a5b6a61195..90acc23886 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -33,6 +33,7 @@ from typing import (
)
import attr
+import ijson
from prometheus_client import Counter
from twisted.internet import defer
@@ -55,11 +56,16 @@ from synapse.api.room_versions import (
)
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
-from synapse.logging.context import make_deferred_yieldable, preserve_fn
+from synapse.logging.context import (
+ get_thread_resource_usage,
+ make_deferred_yieldable,
+ preserve_fn,
+)
from synapse.logging.utils import log_function
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.iterutils import batch_iter
from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
@@ -385,7 +391,6 @@ class FederationClient(FederationBase):
Returns:
A list of PDUs that have valid signatures and hashes.
"""
- deferreds = self._check_sigs_and_hashes(room_version, pdus)
async def handle_check_result(pdu: EventBase, deferred: Deferred):
try:
@@ -420,6 +425,7 @@ class FederationClient(FederationBase):
return res
handle = preserve_fn(handle_check_result)
+ deferreds = self._check_sigs_and_hashes(room_version, pdus)
deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
valid_pdus = await make_deferred_yieldable(
@@ -667,19 +673,37 @@ class FederationClient(FederationBase):
async def send_request(destination) -> Dict[str, Any]:
content = await self._do_send_join(destination, pdu)
- logger.debug("Got content: %s", content)
+ # logger.debug("Got content: %s", content.getvalue())
- state = [
- event_from_pdu_json(p, room_version, outlier=True)
- for p in content.get("state", [])
- ]
+ # logger.info("send_join content: %d", len(content))
- auth_chain = [
- event_from_pdu_json(p, room_version, outlier=True)
- for p in content.get("auth_chain", [])
- ]
+ content.seek(0)
+
+ r = get_thread_resource_usage()
+ logger.info("Memory before state: %s", r.ru_maxrss)
+
+ state = []
+ for i, p in enumerate(ijson.items(content, "state.item")):
+ state.append(event_from_pdu_json(p, room_version, outlier=True))
+ if i % 1000 == 999:
+ await self._clock.sleep(0)
+
+ r = get_thread_resource_usage()
+ logger.info("Memory after state: %s", r.ru_maxrss)
+
+ logger.info("Parsed state: %d", len(state))
+ content.seek(0)
+
+ auth_chain = []
+ for i, p in enumerate(ijson.items(content, "auth_chain.item")):
+ auth_chain.append(event_from_pdu_json(p, room_version, outlier=True))
+ if i % 1000 == 999:
+ await self._clock.sleep(0)
- pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
+ r = get_thread_resource_usage()
+ logger.info("Memory after: %s", r.ru_maxrss)
+
+ logger.info("Parsed auth chain: %d", len(auth_chain))
create_event = None
for e in state:
@@ -704,12 +728,19 @@ class FederationClient(FederationBase):
% (create_room_version,)
)
- valid_pdus = await self._check_sigs_and_hash_and_fetch(
- destination,
- list(pdus.values()),
- outlier=True,
- room_version=room_version,
- )
+ valid_pdus = []
+
+ for chunk in batch_iter(itertools.chain(state, auth_chain), 1000):
+ logger.info("Handling next _check_sigs_and_hash_and_fetch chunk")
+ new_valid_pdus = await self._check_sigs_and_hash_and_fetch(
+ destination,
+ chunk,
+ outlier=True,
+ room_version=room_version,
+ )
+ valid_pdus.extend(new_valid_pdus)
+
+ logger.info("_check_sigs_and_hash_and_fetch done")
valid_pdus_map = {p.event_id: p for p in valid_pdus}
@@ -744,6 +775,8 @@ class FederationClient(FederationBase):
% (auth_chain_create_events,)
)
+ logger.info("Returning from send_join")
+
return {
"state": signed_state,
"auth_chain": signed_auth,
@@ -769,6 +802,8 @@ class FederationClient(FederationBase):
if not self._is_unknown_endpoint(e):
raise
+ raise NotImplementedError()
+
logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
resp = await self.transport_layer.send_join_v1(
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index ada322a81e..9c0a105f36 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -244,7 +244,10 @@ class TransportLayerClient:
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = await self.client.put_json(
- destination=destination, path=path, data=content
+ destination=destination,
+ path=path,
+ data=content,
+ return_string_io=True,
)
return response
@@ -254,7 +257,10 @@ class TransportLayerClient:
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
response = await self.client.put_json(
- destination=destination, path=path, data=content
+ destination=destination,
+ path=path,
+ data=content,
+ return_string_io=True,
)
return response
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index de1b14cde3..4064a2b859 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -78,7 +78,7 @@ class DirectoryHandler(BaseHandler):
# TODO(erikj): Add transactions.
# TODO(erikj): Check if there is a current association.
if not servers:
- users = await self.state.get_current_users_in_room(room_id)
+ users = await self.store.get_users_in_room(room_id)
servers = {get_domain_from_id(u) for u in users}
if not servers:
@@ -270,7 +270,7 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND,
)
- users = await self.state.get_current_users_in_room(room_id)
+ users = await self.store.get_users_in_room(room_id)
extra_servers = {get_domain_from_id(u) for u in users}
servers = set(extra_servers) | set(servers)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index d82144d7fa..f134f1e234 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -103,7 +103,7 @@ class EventStreamHandler(BaseHandler):
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
- users = await self.state.get_current_users_in_room(
+ users = await self.store.get_users_in_room(
event.room_id
) # type: Iterable[str]
else:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 9d867aaf4d..e56242c63e 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -552,7 +552,7 @@ class FederationHandler(BaseHandler):
destination: str,
room_id: str,
event_id: str,
- ) -> Tuple[List[EventBase], List[EventBase]]:
+ ) -> List[EventBase]:
"""Requests all of the room state at a given event from a remote homeserver.
Args:
@@ -573,11 +573,10 @@ class FederationHandler(BaseHandler):
desired_events = set(state_event_ids + auth_event_ids)
- event_map = await self._get_events_from_store_or_dest(
+ failed_to_fetch = await self._get_events_from_store_or_dest(
destination, room_id, desired_events
)
- failed_to_fetch = desired_events - event_map.keys()
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state/auth events for %s %s",
@@ -585,18 +584,44 @@ class FederationHandler(BaseHandler):
failed_to_fetch,
)
+ event_map = await self.store.get_events(state_event_ids, allow_rejected=True)
+
remote_state = [
event_map[e_id] for e_id in state_event_ids if e_id in event_map
]
- auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
- auth_chain.sort(key=lambda e: e.depth)
+ # check for events which were in the wrong room.
+ #
+ # this can happen if a remote server claims that the state or
+ # auth_events at an event in room A are actually events in room B
+
+ bad_events = [
+            (event.event_id, event.room_id)
+            for event in remote_state
+ if event.room_id != room_id
+ ]
+
+ for bad_event_id, bad_room_id in bad_events:
+ # This is a bogus situation, but since we may only discover it a long time
+ # after it happened, we try our best to carry on, by just omitting the
+ # bad events from the returned auth/state set.
+ logger.warning(
+ "Remote server %s claims event %s in room %s is an auth/state "
+ "event in room %s",
+ destination,
+ bad_event_id,
+ bad_room_id,
+ room_id,
+ )
+
+ if bad_events:
+ remote_state = [e for e in remote_state if e.room_id == room_id]
- return remote_state, auth_chain
+ return remote_state
async def _get_events_from_store_or_dest(
self, destination: str, room_id: str, event_ids: Iterable[str]
- ) -> Dict[str, EventBase]:
+ ) -> Set[str]:
"""Fetch events from a remote destination, checking if we already have them.
Persists any events we don't already have as outliers.
@@ -613,54 +638,25 @@ class FederationHandler(BaseHandler):
Returns:
-            map from event_id to event
+            The set of event IDs that we failed to fetch from the destination.
"""
- fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
-
- missing_events = set(event_ids) - fetched_events.keys()
+ have_events = await self.store.have_seen_events(event_ids)
- if missing_events:
- logger.debug(
- "Fetching unknown state/auth events %s for room %s",
- missing_events,
- room_id,
- )
-
- await self._get_events_and_persist(
- destination=destination, room_id=room_id, events=missing_events
- )
-
- # we need to make sure we re-load from the database to get the rejected
- # state correct.
- fetched_events.update(
- (await self.store.get_events(missing_events, allow_rejected=True))
- )
-
- # check for events which were in the wrong room.
- #
- # this can happen if a remote server claims that the state or
- # auth_events at an event in room A are actually events in room B
+ missing_events = set(event_ids) - have_events
- bad_events = [
- (event_id, event.room_id)
- for event_id, event in fetched_events.items()
- if event.room_id != room_id
- ]
+ if not missing_events:
+ return set()
- for bad_event_id, bad_room_id in bad_events:
- # This is a bogus situation, but since we may only discover it a long time
- # after it happened, we try our best to carry on, by just omitting the
- # bad events from the returned auth/state set.
- logger.warning(
- "Remote server %s claims event %s in room %s is an auth/state "
- "event in room %s",
- destination,
- bad_event_id,
- bad_room_id,
- room_id,
- )
+ logger.debug(
+ "Fetching unknown state/auth events %s for room %s",
+ missing_events,
+ room_id,
+ )
- del fetched_events[bad_event_id]
+ await self._get_events_and_persist(
+ destination=destination, room_id=room_id, events=missing_events
+ )
- return fetched_events
+ new_events = await self.store.have_seen_events(missing_events)
+ return missing_events - new_events
async def _get_state_after_missing_prev_event(
self,
@@ -963,27 +959,23 @@ class FederationHandler(BaseHandler):
# For each edge get the current state.
- auth_events = {}
state_events = {}
events_to_state = {}
for e_id in edges:
- state, auth = await self._get_state_for_room(
+ state = await self._get_state_for_room(
destination=dest,
room_id=room_id,
event_id=e_id,
)
- auth_events.update({a.event_id: a for a in auth})
- auth_events.update({s.event_id: s for s in state})
state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state
required_auth = {
a_id
- for event in events
- + list(state_events.values())
- + list(auth_events.values())
+ for event in events + list(state_events.values())
for a_id in event.auth_event_ids()
}
+ auth_events = await self.store.get_events(required_auth, allow_rejected=True)
auth_events.update(
{e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
)
@@ -1452,7 +1444,7 @@ class FederationHandler(BaseHandler):
# room stuff after join currently doesn't work on workers.
assert self.config.worker.worker_app is None
- logger.debug("Joining %s to %s", joinee, room_id)
+ logger.info("Joining %s to %s", joinee, room_id)
origin, event, room_version_obj = await self._make_and_verify_event(
target_hosts,
@@ -1463,6 +1455,8 @@ class FederationHandler(BaseHandler):
params={"ver": KNOWN_ROOM_VERSIONS},
)
+ logger.info("make_join done from %s", origin)
+
# This shouldn't happen, because the RoomMemberHandler has a
# linearizer lock which only allows one operation per user per room
# at a time - so this is just paranoia.
@@ -1482,10 +1476,13 @@ class FederationHandler(BaseHandler):
except ValueError:
pass
+ logger.info("Sending join")
ret = await self.federation_client.send_join(
host_list, event, room_version_obj
)
+ logger.info("send join done")
+
origin = ret["origin"]
state = ret["state"]
auth_chain = ret["auth_chain"]
@@ -1510,10 +1507,14 @@ class FederationHandler(BaseHandler):
room_version=room_version_obj,
)
+ logger.info("Persisting auth true")
+
max_stream_id = await self._persist_auth_tree(
origin, room_id, auth_chain, state, event, room_version_obj
)
+ logger.info("Persisted auth true")
+
# We wait here until this instance has seen the events come down
# replication (if we're using replication) as the below uses caches.
await self._replication.wait_for_stream_position(
@@ -2166,6 +2167,8 @@ class FederationHandler(BaseHandler):
ctx = await self.state_handler.compute_event_context(e)
events_to_context[e.event_id] = ctx
+ logger.info("Computed contexts")
+
event_map = {
e.event_id: e for e in itertools.chain(auth_events, state, [event])
}
@@ -2207,6 +2210,8 @@ class FederationHandler(BaseHandler):
else:
logger.info("Failed to find auth event %r", e_id)
+ logger.info("Got missing events")
+
for e in itertools.chain(auth_events, state, [event]):
auth_for_e = {
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
@@ -2231,6 +2236,8 @@ class FederationHandler(BaseHandler):
raise
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
+ logger.info("Authed events")
+
await self.persist_events_and_notify(
room_id,
[
@@ -2239,10 +2246,14 @@ class FederationHandler(BaseHandler):
],
)
+ logger.info("Persisted events")
+
new_event_context = await self.state_handler.compute_event_context(
event, old_state=state
)
+ logger.info("Computed context")
+
return await self.persist_events_and_notify(
room_id, [(event, new_event_context)]
)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index d4684755bc..9919cccb19 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -258,7 +258,7 @@ class MessageHandler:
"Getting joined members after leaving is not implemented"
)
- users_with_profile = await self.state.get_current_users_in_room(room_id)
+ users_with_profile = await self.store.get_users_in_room_with_profiles(room_id)
# If this is an AS, double check that they are allowed to see the members.
# This can either be because the AS user is in the room or because there
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index ebbc234334..6fd1f34289 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1183,7 +1183,16 @@ class PresenceHandler(BasePresenceHandler):
max_pos, deltas = await self.store.get_current_state_deltas(
self._event_pos, room_max_stream_ordering
)
- await self._handle_state_delta(deltas)
+
+ # We may get multiple deltas for different rooms, but we want to
+ # handle them on a room by room basis, so we batch them up by
+ # room.
+ deltas_by_room: Dict[str, List[JsonDict]] = {}
+ for delta in deltas:
+ deltas_by_room.setdefault(delta["room_id"], []).append(delta)
+
+ for room_id, deltas_for_room in deltas_by_room.items():
+ await self._handle_state_delta(room_id, deltas_for_room)
self._event_pos = max_pos
@@ -1192,17 +1201,21 @@ class PresenceHandler(BasePresenceHandler):
max_pos
)
- async def _handle_state_delta(self, deltas: List[JsonDict]) -> None:
- """Process current state deltas to find new joins that need to be
- handled.
+ async def _handle_state_delta(self, room_id: str, deltas: List[JsonDict]) -> None:
+ """Process current state deltas for the room to find new joins that need
+ to be handled.
"""
- # A map of destination to a set of user state that they should receive
- presence_destinations = {} # type: Dict[str, Set[UserPresenceState]]
+
+ # Sets of newly joined users. Note that if the local server is
+ # joining a remote room for the first time we'll see both the joining
+ # user and all remote users as newly joined.
+ newly_joined_users = set()
for delta in deltas:
+ assert room_id == delta["room_id"]
+
typ = delta["type"]
state_key = delta["state_key"]
- room_id = delta["room_id"]
event_id = delta["event_id"]
prev_event_id = delta["prev_event_id"]
@@ -1231,72 +1244,55 @@ class PresenceHandler(BasePresenceHandler):
# Ignore changes to join events.
continue
- # Retrieve any user presence state updates that need to be sent as a result,
- # and the destinations that need to receive it
- destinations, user_presence_states = await self._on_user_joined_room(
- room_id, state_key
- )
-
- # Insert the destinations and respective updates into our destinations dict
- for destination in destinations:
- presence_destinations.setdefault(destination, set()).update(
- user_presence_states
- )
-
- # Send out user presence updates for each destination
- for destination, user_state_set in presence_destinations.items():
- self._federation_queue.send_presence_to_destinations(
- destinations=[destination], states=user_state_set
- )
-
- async def _on_user_joined_room(
- self, room_id: str, user_id: str
- ) -> Tuple[List[str], List[UserPresenceState]]:
- """Called when we detect a user joining the room via the current state
- delta stream. Returns the destinations that need to be updated and the
- presence updates to send to them.
-
- Args:
- room_id: The ID of the room that the user has joined.
- user_id: The ID of the user that has joined the room.
-
- Returns:
- A tuple of destinations and presence updates to send to them.
- """
- if self.is_mine_id(user_id):
- # If this is a local user then we need to send their presence
- # out to hosts in the room (who don't already have it)
-
- # TODO: We should be able to filter the hosts down to those that
- # haven't previously seen the user
-
- remote_hosts = await self.state.get_current_hosts_in_room(room_id)
+ newly_joined_users.add(state_key)
- # Filter out ourselves.
- filtered_remote_hosts = [
- host for host in remote_hosts if host != self.server_name
- ]
-
- state = await self.current_state_for_user(user_id)
- return filtered_remote_hosts, [state]
- else:
- # A remote user has joined the room, so we need to:
- # 1. Check if this is a new server in the room
- # 2. If so send any presence they don't already have for
- # local users in the room.
-
- # TODO: We should be able to filter the users down to those that
- # the server hasn't previously seen
-
- # TODO: Check that this is actually a new server joining the
- # room.
-
- remote_host = get_domain_from_id(user_id)
+ if not newly_joined_users:
+ # If nobody has joined then there's nothing to do.
+ return
- users = await self.state.get_current_users_in_room(room_id)
- user_ids = list(filter(self.is_mine_id, users))
+ # We want to send:
+ # 1. presence states of all local users in the room to newly joined
+ # remote servers
+ # 2. presence states of newly joined users to all remote servers in
+ # the room.
+ #
+ # TODO: Only send presence states to remote hosts that don't already
+ # have them (because they already share rooms).
+
+ # Get all the users who were already in the room, by fetching the
+ # current users in the room and removing the newly joined users.
+ users = await self.store.get_users_in_room(room_id)
+ prev_users = set(users) - newly_joined_users
+
+ # Construct sets for all the local users and remote hosts that were
+ # already in the room
+ prev_local_users = []
+ prev_remote_hosts = set()
+ for user_id in prev_users:
+ if self.is_mine_id(user_id):
+ prev_local_users.append(user_id)
+ else:
+ prev_remote_hosts.add(get_domain_from_id(user_id))
+
+ # Similarly, construct sets for all the local users and remote hosts
+        # that were *not* already in the room. Care needs to be taken when
+        # calculating the remote hosts, as a host may have already been in the
+ # room even if there is a newly joined user from that host.
+ newly_joined_local_users = []
+ newly_joined_remote_hosts = set()
+ for user_id in newly_joined_users:
+ if self.is_mine_id(user_id):
+ newly_joined_local_users.append(user_id)
+ else:
+ host = get_domain_from_id(user_id)
+ if host not in prev_remote_hosts:
+ newly_joined_remote_hosts.add(host)
- states_d = await self.current_state_for_users(user_ids)
+ # Send presence states of all local users in the room to newly joined
+ # remote servers. (We actually only send states for local users already
+ # in the room, as we'll send states for newly joined local users below.)
+ if prev_local_users and newly_joined_remote_hosts:
+ local_states = await self.current_state_for_users(prev_local_users)
# Filter out old presence, i.e. offline presence states where
# the user hasn't been active for a week. We can change this
@@ -1306,13 +1302,27 @@ class PresenceHandler(BasePresenceHandler):
now = self.clock.time_msec()
states = [
state
- for state in states_d.values()
+ for state in local_states.values()
if state.state != PresenceState.OFFLINE
or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
or state.status_msg is not None
]
- return [remote_host], states
+ self._federation_queue.send_presence_to_destinations(
+ destinations=newly_joined_remote_hosts,
+ states=states,
+ )
+
+ # Send presence states of newly joined users to all remote servers in
+ # the room
+ if newly_joined_local_users and (
+ prev_remote_hosts or newly_joined_remote_hosts
+ ):
+ local_states = await self.current_state_for_users(newly_joined_local_users)
+ self._federation_queue.send_presence_to_destinations(
+ destinations=prev_remote_hosts | newly_joined_remote_hosts,
+ states=list(local_states.values()),
+ )
def should_notify(old_state: UserPresenceState, new_state: UserPresenceState) -> bool:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 88e9d0e2fe..0b805af60c 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1327,7 +1327,7 @@ class RoomShutdownHandler:
new_room_id = None
logger.info("Shutting down room %r", room_id)
- users = await self.state.get_current_users_in_room(room_id)
+ users = await self.store.get_users_in_room(room_id)
kicked_users = []
failed_to_kick_users = []
for user_id in users:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index a9a3ee05c3..0fcc1532da 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1190,7 +1190,7 @@ class SyncHandler:
# Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms:
- joined_users = await self.state.get_current_users_in_room(room_id)
+ joined_users = await self.store.get_users_in_room(room_id)
newly_joined_or_invited_users.update(joined_users)
# TODO: Check that these users are actually new, i.e. either they
@@ -1206,7 +1206,7 @@ class SyncHandler:
# Now find users that we no longer track
for room_id in newly_left_rooms:
- left_users = await self.state.get_current_users_in_room(room_id)
+ left_users = await self.store.get_users_in_room(room_id)
newly_left_users.update(left_users)
# Remove any users that we still share a room with.
@@ -1361,7 +1361,7 @@ class SyncHandler:
extra_users_ids = set(newly_joined_or_invited_users)
for room_id in newly_joined_rooms:
- users = await self.state.get_current_users_in_room(room_id)
+ users = await self.store.get_users_in_room(room_id)
extra_users_ids.update(users)
extra_users_ids.discard(user.to_string())
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index bb837b7b19..6db1aece35 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -154,6 +154,7 @@ async def _handle_json_response(
request: MatrixFederationRequest,
response: IResponse,
start_ms: int,
+    return_string_io: bool = False,
) -> JsonDict:
"""
Reads the JSON body of a response, with a timeout
@@ -175,12 +176,12 @@ async def _handle_json_response(
d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE)
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
- def parse(_len: int):
- return json_decoder.decode(buf.getvalue())
+ await make_deferred_yieldable(d)
- d.addCallback(parse)
-
- body = await make_deferred_yieldable(d)
+ if return_string_io:
+ body = buf
+ else:
+ body = json_decoder.decode(buf.getvalue())
except BodyExceededMaxSize as e:
# The response was too big.
logger.warning(
@@ -225,12 +226,13 @@ async def _handle_json_response(
time_taken_secs = reactor.seconds() - start_ms / 1000
logger.info(
- "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s",
+ "{%s} [%s] Completed request: %d %s in %.2f secs got %dB - %s %s",
request.txn_id,
request.destination,
response.code,
response.phrase.decode("ascii", errors="replace"),
time_taken_secs,
+ len(buf.getvalue()),
request.method,
request.uri.decode("ascii"),
)
@@ -683,6 +685,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
+    return_string_io: bool = False,
) -> Union[JsonDict, list]:
"""Sends the specified json data using PUT
@@ -757,7 +760,12 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout
body = await _handle_json_response(
- self.reactor, _sec_timeout, request, response, start_ms
+ self.reactor,
+ _sec_timeout,
+ request,
+ response,
+ start_ms,
+ return_string_io=return_string_io,
)
return body
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 31b7b3c256..fef2846669 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -535,6 +535,13 @@ class ReactorLastSeenMetric:
REGISTRY.register(ReactorLastSeenMetric())
+# The minimum time in seconds between GCs for each generation, regardless of the current GC
+# thresholds and counts.
+MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)
+
+# The time (in seconds since the epoch) of the last time we did a GC for each generation.
+_last_gc = [0.0, 0.0, 0.0]
+
def runUntilCurrentTimer(reactor, func):
@functools.wraps(func)
@@ -575,11 +582,16 @@ def runUntilCurrentTimer(reactor, func):
return ret
# Check if we need to do a manual GC (since its been disabled), and do
- # one if necessary.
+ # one if necessary. Note we go in reverse order as e.g. a gen 1 GC may
+ # promote an object into gen 2, and we don't want to handle the same
+ # object multiple times.
threshold = gc.get_threshold()
counts = gc.get_count()
for i in (2, 1, 0):
- if threshold[i] < counts[i]:
+ # We check if we need to do one based on a straightforward
+ # comparison between the threshold and count. We also do an extra
+            # check to make sure that we don't do a GC too often.
+ if threshold[i] < counts[i] and MIN_TIME_BETWEEN_GCS[i] < end - _last_gc[i]:
if i == 0:
logger.debug("Collecting gc %d", i)
else:
@@ -589,6 +601,8 @@ def runUntilCurrentTimer(reactor, func):
unreachable = gc.collect(i)
end = time.time()
+ _last_gc[i] = end
+
gc_time.labels(i).observe(end - start)
gc_unreachable.labels(i).set(unreachable)
@@ -615,6 +629,7 @@ try:
except AttributeError:
pass
+
__all__ = [
"MetricsResource",
"generate_latest",
diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py
new file mode 100644
index 0000000000..07ed1d2453
--- /dev/null
+++ b/synapse/metrics/jemalloc.py
@@ -0,0 +1,191 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ctypes
+import logging
+import os
+import re
+from typing import Optional
+
+from synapse.metrics import REGISTRY, GaugeMetricFamily
+
+logger = logging.getLogger(__name__)
+
+
+def _setup_jemalloc_stats():
+ """Checks to see if jemalloc is loaded, and hooks up a collector to record
+ statistics exposed by jemalloc.
+ """
+
+ # Try to find the loaded jemalloc shared library, if any. We need to
+ # introspect what is already loaded, rather than loading whatever is on
+ # the path, since loading a *different* jemalloc version would segfault.
+
+ # We look in `/proc/self/maps`, which only exists on linux.
+ if not os.path.exists("/proc/self/maps"):
+ logger.debug("Not looking for jemalloc as no /proc/self/maps exist")
+ return
+
+ # We're looking for a path at the end of the line that includes
+ # "libjemalloc".
+ regex = re.compile(r"/\S+/libjemalloc.*$")
+
+ jemalloc_path = None
+ with open("/proc/self/maps") as f:
+ for line in f:
+ match = regex.search(line.strip())
+ if match:
+ jemalloc_path = match.group()
+
+ if not jemalloc_path:
+ # No loaded jemalloc was found.
+ logger.debug("jemalloc not found")
+ return
+
+ jemalloc = ctypes.CDLL(jemalloc_path)
+
+ def _mallctl(
+ name: str, read: bool = True, write: Optional[int] = None
+ ) -> Optional[int]:
+ """Wrapper around `mallctl` for reading and writing integers to
+ jemalloc.
+
+ Args:
+ name: The name of the option to read from/write to.
+ read: Whether to try and read the value.
+ write: The value to write, if given.
+
+ Returns:
+ The value read if `read` is True, otherwise None.
+
+ Raises:
+ An exception if `mallctl` returns a non-zero error code.
+ """
+
+ input_var = None
+ input_var_ref = None
+ input_len_ref = None
+ if read:
+ input_var = ctypes.c_size_t(0)
+ input_len = ctypes.c_size_t(ctypes.sizeof(input_var))
+
+ input_var_ref = ctypes.byref(input_var)
+ input_len_ref = ctypes.byref(input_len)
+
+ write_var_ref = None
+ write_len = ctypes.c_size_t(0)
+ if write is not None:
+ write_var = ctypes.c_size_t(write)
+ write_len = ctypes.c_size_t(ctypes.sizeof(write_var))
+
+ write_var_ref = ctypes.byref(write_var)
+
+ # The interface is:
+ #
+ # int mallctl(
+ # const char *name,
+ # void *oldp,
+ # size_t *oldlenp,
+ # void *newp,
+ # size_t newlen
+ # )
+ #
+ # Where oldp/oldlenp is a buffer where the old value will be written to
+ # (if not null), and newp/newlen is the buffer with the new value to set
+ # (if not null). Note that they're all references *except* newlen.
+ result = jemalloc.mallctl(
+ name.encode("ascii"),
+ input_var_ref,
+ input_len_ref,
+ write_var_ref,
+ write_len,
+ )
+
+ if result != 0:
+ raise Exception("Failed to call mallctl")
+
+ if input_var is None:
+ return None
+
+ return input_var.value
+
+ def _jemalloc_refresh_stats() -> None:
+ """Request that jemalloc updates its internal statistics. This needs to
+ be called before querying for stats, otherwise it will return stale
+ values.
+ """
+ try:
+ _mallctl("epoch", read=False, write=1)
+ except Exception:
+ pass
+
+ class JemallocCollector:
+ """Metrics for internal jemalloc stats."""
+
+ def collect(self):
+ _jemalloc_refresh_stats()
+
+ g = GaugeMetricFamily(
+ "jemalloc_stats_app_memory_bytes",
+ "The stats reported by jemalloc",
+ labels=["type"],
+ )
+
+ # Read the relevant global stats from jemalloc. Note that these may
+ # not be accurate if Python is configured to use its internal small
+ # object allocator (which is on by default; disable it by setting the
+ # env var `PYTHONMALLOC=malloc`).
+ #
+ # See the jemalloc manpage for details about what each value means,
+ # roughly:
+ # - allocated - Total number of bytes allocated by the app.
+ # - active - Total number of bytes in active pages allocated by
+ # the application; this is bigger than `allocated`.
+ # - resident - Maximum number of bytes in physically resident data
+ # pages mapped by the allocator, comprising all pages dedicated
+ # to allocator metadata, pages backing active allocations, and
+ # unused dirty pages. This is bigger than `active`.
+ # - mapped - Total number of bytes in active extents mapped by the
+ # allocator.
+ # - metadata - Total number of bytes dedicated to jemalloc
+ # metadata.
+ for t in (
+ "allocated",
+ "active",
+ "resident",
+ "mapped",
+ "metadata",
+ ):
+ try:
+ value = _mallctl(f"stats.{t}")
+ except Exception:
+ # There was an error fetching the value; skip it.
+ continue
+
+ g.add_metric([t], value=value)
+
+ yield g
+
+ REGISTRY.register(JemallocCollector())
+
+ logger.debug("Added jemalloc stats")
+
+
+def setup_jemalloc_stats():
+ """Try to setup jemalloc stats, if jemalloc is loaded."""
+
+ try:
+ _setup_jemalloc_stats()
+ except Exception as e:
+ logger.info("Failed to setup collector to record jemalloc stats: %s", e)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 2de946f464..d58eeeaa74 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -116,6 +116,8 @@ CONDITIONAL_REQUIREMENTS = {
# hiredis is not a *strict* dependency, but it makes things much faster.
# (if it is not installed, we fall back to slow code.)
"redis": ["txredisapi>=1.4.7", "hiredis"],
+ # Required to use experimental `caches.track_memory_usage` config option.
+ "cache_memory": ["pympler"],
}
ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f648678b09..e19e9ef5c7 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -22,6 +22,7 @@ from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.util import json_decoder
+from synapse.util.async_helpers import yieldable_gather_results
logger = logging.getLogger(__name__)
@@ -213,7 +214,13 @@ class RemoteKey(DirectServeJsonResource):
# If there is a cache miss, request the missing keys, then recurse (and
# ensure the result is sent).
if cache_misses and query_remote_on_cache_miss:
- await self.fetcher.get_keys(cache_misses)
+ await yieldable_gather_results(
+ lambda t: self.fetcher.get_keys(*t),
+ (
+ (server_name, list(keys), 0)
+ for server_name, keys in cache_misses.items()
+ ),
+ )
await self.query_keys(request, query, query_remote_on_cache_miss=False)
else:
signed_keys = []
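
The hunk above replaces one batched `get_keys` call with a concurrent per-server fan-out, unpacking each `(server_name, list(keys), 0)` tuple into the fetcher. A rough analogy of that shape, using `asyncio.gather` as a stand-in for Synapse's logcontext-aware `yieldable_gather_results`; the fetcher signature is inferred from the tuple built above and the implementation is omitted:

```python
import asyncio
from typing import Dict, Iterable, List


async def fetch_keys(server_name: str, key_ids: List[str], minimum_valid_until_ts: int) -> None:
    """Stand-in for ServerKeyFetcher.get_keys; body intentionally omitted."""
    ...


async def fetch_all(cache_misses: Dict[str, Iterable[str]]) -> None:
    # One coroutine per origin server, run concurrently rather than as a
    # single batched request covering every server at once.
    await asyncio.gather(
        *(
            fetch_keys(server_name, list(keys), 0)
            for server_name, keys in cache_misses.items()
        )
    )

# e.g. asyncio.run(fetch_all({"remote.example.com": ["ed25519:abc"]}))
```
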
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index b3bd92d37c..a1770f620e 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -213,19 +213,23 @@ class StateHandler:
return ret.state
async def get_current_users_in_room(
- self, room_id: str, latest_event_ids: Optional[List[str]] = None
+ self, room_id: str, latest_event_ids: List[str]
) -> Dict[str, ProfileInfo]:
"""
Get the users who are currently in a room.
+ Note: This is much slower than using the equivalent method
+ `DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`,
+ so this should only be used when you need the users at a particular point
+ in the room's history.
+
Args:
room_id: The ID of the room.
- latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
+ latest_event_ids: Precomputed list of latest event IDs.
Returns:
Dictionary of user IDs to their profileinfo.
"""
- if not latest_event_ids:
- latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_users_in_room")
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 6b68d8720c..3d98d3f5f8 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -69,6 +69,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
+ self._attempt_to_invalidate_cache("get_users_in_room_with_profiles", (room_id,))
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 2a8532f8c1..5fc3bb5a7d 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -205,8 +205,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]:
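+ # Join against current_state_events so we only pick up each user's
+ # *current* membership event, rather than every historical join row
+ # stored in room_memberships.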
sql = """
- SELECT user_id, display_name, avatar_url FROM room_memberships
- WHERE room_id = ? AND membership = ?
+ SELECT state_key, display_name, avatar_url FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
"""
txn.execute(sql, (room_id, Membership.JOIN))
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 7a082fdd21..a6bfb4902a 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -142,8 +142,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
batch_size (int): Maximum number of state events to process
per cycle.
"""
- state = self.hs.get_state_handler()
-
# If we don't have progress filed, delete everything.
if not progress:
await self.delete_all_from_user_dir()
@@ -197,7 +195,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
room_id
)
- users_with_profile = await state.get_current_users_in_room(room_id)
+ users_with_profile = await self.get_users_in_room_with_profiles(room_id)
user_ids = set(users_with_profile)
# Update each user in the user directory.
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 46af7fa473..ca36f07c20 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -24,6 +24,11 @@ from synapse.config.cache import add_resizable_cache
logger = logging.getLogger(__name__)
+
+# Whether to track estimated memory usage of the LruCaches.
+TRACK_MEMORY_USAGE = False
+
+
caches_by_name = {} # type: Dict[str, Sized]
collectors_by_name = {} # type: Dict[str, CacheMetric]
@@ -32,6 +37,11 @@ cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"])
cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"])
cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"])
+cache_memory_usage = Gauge(
+ "synapse_util_caches_cache_size_bytes",
+ "Estimated memory usage of the caches",
+ ["name"],
+)
response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"])
response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"])
@@ -52,6 +62,7 @@ class CacheMetric:
hits = attr.ib(default=0)
misses = attr.ib(default=0)
evicted_size = attr.ib(default=0)
+ memory_usage = attr.ib(default=None)
def inc_hits(self):
self.hits += 1
@@ -62,6 +73,19 @@ class CacheMetric:
def inc_evictions(self, size=1):
self.evicted_size += size
+ def inc_memory_usage(self, memory: int):
+ if self.memory_usage is None:
+ self.memory_usage = 0
+
+ self.memory_usage += memory
+
+ def dec_memory_usage(self, memory: int):
+ self.memory_usage -= memory
+
+ def clear_memory_usage(self):
+ if self.memory_usage is not None:
+ self.memory_usage = 0
+
def describe(self):
return []
@@ -81,6 +105,13 @@ class CacheMetric:
cache_total.labels(self._cache_name).set(self.hits + self.misses)
if getattr(self._cache, "max_size", None):
cache_max_size.labels(self._cache_name).set(self._cache.max_size)
+
+ if TRACK_MEMORY_USAGE:
+ # self.memory_usage can be None if nothing has been inserted
+ # into the cache yet.
+ cache_memory_usage.labels(self._cache_name).set(
+ self.memory_usage or 0
+ )
if self._collect_callback:
self._collect_callback()
except Exception as e:
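
A condensed, standalone sketch of the bookkeeping added to `CacheMetric` above; the gauge name mirrors the patch, while the cache name and byte counts are made up:

```python
from prometheus_client import Gauge

TRACK_MEMORY_USAGE = True  # in the patch this is set from config at startup

cache_memory_usage = Gauge(
    "synapse_util_caches_cache_size_bytes",
    "Estimated memory usage of the caches",
    ["name"],
)

# Running total for one cache; None until the first insertion, as in the patch.
memory_usage = None


def inc_memory_usage(n: int) -> None:
    global memory_usage
    memory_usage = (memory_usage or 0) + n


def collect(cache_name: str) -> None:
    if TRACK_MEMORY_USAGE:
        cache_memory_usage.labels(cache_name).set(memory_usage or 0)


inc_memory_usage(4096)        # e.g. an entry estimated at 4 KiB
collect("get_users_in_room")  # gauge now reports 4096 for this cache
```
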
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 10b0ec6b75..1be675e014 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -32,9 +32,36 @@ from typing import (
from typing_extensions import Literal
from synapse.config import cache as cache_config
+from synapse.util import caches
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.treecache import TreeCache
+try:
+ from pympler.asizeof import Asizer
+
+ def _get_size_of(val: Any, *, recurse=True) -> int:
+ """Get an estimate of the size in bytes of the object.
+
+ Args:
+ val: The object to size.
+ recurse: If true will include referenced values in the size,
+ otherwise only sizes the given object.
+ """
+ # Ignore singleton values when calculating memory usage.
+ if val in ((), None, ""):
+ return 0
+
+ sizer = Asizer()
+ sizer.exclude_refs((), None, "")
+ return sizer.asizeof(val, limit=100 if recurse else 0)
+
+
+except ImportError:
+
+ def _get_size_of(val: Any, *, recurse=True) -> int:
+ return 0
+
+
# Function type: the type used for invalidation callbacks
FT = TypeVar("FT", bound=Callable[..., Any])
@@ -56,7 +83,7 @@ def enumerate_leaves(node, depth):
class _Node:
- __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
+ __slots__ = ["prev_node", "next_node", "key", "value", "callbacks", "memory"]
def __init__(
self,
@@ -84,6 +111,16 @@ class _Node:
self.add_callbacks(callbacks)
+ self.memory = 0
+ if caches.TRACK_MEMORY_USAGE:
+ self.memory = (
+ _get_size_of(key)
+ + _get_size_of(value)
+ + _get_size_of(self.callbacks, recurse=False)
+ + _get_size_of(self, recurse=False)
+ )
+ self.memory += _get_size_of(self.memory, recurse=False)
+
def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None:
"""Add to stored list of callbacks, removing duplicates."""
@@ -233,6 +270,9 @@ class LruCache(Generic[KT, VT]):
if size_callback:
cached_cache_len[0] += size_callback(node.value)
+ if caches.TRACK_MEMORY_USAGE and metrics:
+ metrics.inc_memory_usage(node.memory)
+
def move_node_to_front(node):
prev_node = node.prev_node
next_node = node.next_node
@@ -258,6 +298,9 @@ class LruCache(Generic[KT, VT]):
node.run_and_clear_callbacks()
+ if caches.TRACK_MEMORY_USAGE and metrics:
+ metrics.dec_memory_usage(node.memory)
+
return deleted_len
@overload
@@ -373,6 +416,9 @@ class LruCache(Generic[KT, VT]):
if size_callback:
cached_cache_len[0] = 0
+ if caches.TRACK_MEMORY_USAGE and metrics:
+ metrics.clear_memory_usage()
+
@synchronized
def cache_contains(key: KT) -> bool:
return key in cache
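
To give reviewers a feel for what `_get_size_of` measures, a small standalone sketch using pympler directly; the `exclude_refs`/`limit` usage mirrors the patch, and the sample cache value is made up:

```python
from pympler.asizeof import Asizer


def get_size_of(val, *, recurse: bool = True) -> int:
    """Estimate the deep (or shallow) size of `val` in bytes, as in the patch."""
    # Ignore shared singletons when calculating memory usage.
    if val in ((), None, ""):
        return 0
    sizer = Asizer()
    sizer.exclude_refs((), None, "")
    return sizer.asizeof(val, limit=100 if recurse else 0)


entry = {"displayname": "Alice", "avatar_url": "mxc://example.org/abc"}  # made-up cache value
print(get_size_of(entry))                 # deep size, following references
print(get_size_of(entry, recurse=False))  # shallow size of the dict object only
```
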