diff options
Diffstat (limited to 'synapse')
43 files changed, 1022 insertions, 594 deletions
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index 092601f530..0c4504d5d8 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -1,6 +1,6 @@ # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2018 New Vector -# Copyright 2021 The Matrix.org Foundation C.I.C. +# Copyright 2021-22 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,11 +20,22 @@ import hashlib import hmac import logging import sys -from typing import Callable, Optional +from typing import Any, Callable, Dict, Optional import requests import yaml +_CONFLICTING_SHARED_SECRET_OPTS_ERROR = """\ +Conflicting options 'registration_shared_secret' and 'registration_shared_secret_path' +are both defined in config file. +""" + +_NO_SHARED_SECRET_OPTS_ERROR = """\ +No 'registration_shared_secret' or 'registration_shared_secret_path' defined in config. +""" + +_DEFAULT_SERVER_URL = "http://localhost:8008" + def request_registration( user: str, @@ -203,31 +214,104 @@ def main() -> None: parser.add_argument( "server_url", - default="https://localhost:8448", nargs="?", - help="URL to use to talk to the homeserver. Defaults to " - " 'https://localhost:8448'.", + help="URL to use to talk to the homeserver. By default, tries to find a " + "suitable URL from the configuration file. Otherwise, defaults to " + f"'{_DEFAULT_SERVER_URL}'.", ) args = parser.parse_args() if "config" in args and args.config: config = yaml.safe_load(args.config) - secret = config.get("registration_shared_secret", None) + + if args.shared_secret: + secret = args.shared_secret + else: + # argparse should check that we have either config or shared secret + assert config + + secret = config.get("registration_shared_secret") + secret_file = config.get("registration_shared_secret_path") + if secret_file: + if secret: + print(_CONFLICTING_SHARED_SECRET_OPTS_ERROR, file=sys.stderr) + sys.exit(1) + secret = _read_file(secret_file, "registration_shared_secret_path").strip() if not secret: - print("No 'registration_shared_secret' defined in config.") + print(_NO_SHARED_SECRET_OPTS_ERROR, file=sys.stderr) sys.exit(1) + + if args.server_url: + server_url = args.server_url + elif config: + server_url = _find_client_listener(config) + if not server_url: + server_url = _DEFAULT_SERVER_URL + print( + "Unable to find a suitable HTTP listener in the configuration file. " + f"Trying {server_url} as a last resort.", + file=sys.stderr, + ) else: - secret = args.shared_secret + server_url = _DEFAULT_SERVER_URL + print( + f"No server url or configuration file given. Defaulting to {server_url}.", + file=sys.stderr, + ) admin = None if args.admin or args.no_admin: admin = args.admin register_new_user( - args.user, args.password, args.server_url, secret, admin, args.user_type + args.user, args.password, server_url, secret, admin, args.user_type ) +def _read_file(file_path: Any, config_path: str) -> str: + """Check the given file exists, and read it into a string + + If it does not, exit with an error indicating the problem + + Args: + file_path: the file to be read + config_path: where in the configuration file_path came from, so that a useful + error can be emitted if it does not exist. + Returns: + content of the file. + """ + if not isinstance(file_path, str): + print(f"{config_path} setting is not a string", file=sys.stderr) + sys.exit(1) + + try: + with open(file_path) as file_stream: + return file_stream.read() + except OSError as e: + print(f"Error accessing file {file_path}: {e}", file=sys.stderr) + sys.exit(1) + + +def _find_client_listener(config: Dict[str, Any]) -> Optional[str]: + # try to find a listener in the config. Returns a host:port pair + for listener in config.get("listeners", []): + if listener.get("type") != "http" or listener.get("tls", False): + continue + + if not any( + name == "client" + for resource in listener.get("resources", []) + for name in resource.get("names", []) + ): + continue + + # TODO: consider bind_addresses + return f"http://localhost:{listener['port']}" + + # no suitable listeners? + return None + + if __name__ == "__main__": main() diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 923891ae0d..4742435d3b 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -266,15 +266,48 @@ def register_start( reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper())) -def listen_metrics(bind_addresses: Iterable[str], port: int) -> None: +def listen_metrics( + bind_addresses: Iterable[str], port: int, enable_legacy_metric_names: bool +) -> None: """ Start Prometheus metrics server. """ - from synapse.metrics import RegistryProxy, start_http_server + from prometheus_client import start_http_server as start_http_server_prometheus + + from synapse.metrics import ( + RegistryProxy, + start_http_server as start_http_server_legacy, + ) for host in bind_addresses: logger.info("Starting metrics listener on %s:%d", host, port) - start_http_server(port, addr=host, registry=RegistryProxy) + if enable_legacy_metric_names: + start_http_server_legacy(port, addr=host, registry=RegistryProxy) + else: + _set_prometheus_client_use_created_metrics(False) + start_http_server_prometheus(port, addr=host, registry=RegistryProxy) + + +def _set_prometheus_client_use_created_metrics(new_value: bool) -> None: + """ + Sets whether prometheus_client should expose `_created`-suffixed metrics for + all gauges, histograms and summaries. + There is no programmatic way to disable this without poking at internals; + the proper way is to use an environment variable which prometheus_client + loads at import time. + + The motivation for disabling these `_created` metrics is that they're + a waste of space as they're not useful but they take up space in Prometheus. + """ + + import prometheus_client.metrics + + if hasattr(prometheus_client.metrics, "_use_created"): + prometheus_client.metrics._use_created = new_value + else: + logger.error( + "Can't disable `_created` metrics in prometheus_client (brittle hack broken?)" + ) def listen_manhole( diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 30e21d9707..5e3825fca6 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -412,7 +412,11 @@ class GenericWorkerServer(HomeServer): "enable_metrics is not True!" ) else: - _base.listen_metrics(listener.bind_addresses, listener.port) + _base.listen_metrics( + listener.bind_addresses, + listener.port, + enable_legacy_metric_names=self.config.metrics.enable_legacy_metrics, + ) else: logger.warning("Unsupported listener type: %s", listener.type) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 68993d91a9..e57a926032 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -307,7 +307,11 @@ class SynapseHomeServer(HomeServer): "enable_metrics is not True!" ) else: - _base.listen_metrics(listener.bind_addresses, listener.port) + _base.listen_metrics( + listener.bind_addresses, + listener.port, + enable_legacy_metric_names=self.config.metrics.enable_legacy_metrics, + ) else: # this shouldn't happen, as the listener type should have been checked # during parsing diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 7c9cf403ef..1f6362aedd 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -20,6 +20,7 @@ import logging import os import re from collections import OrderedDict +from enum import Enum, auto from hashlib import sha256 from textwrap import dedent from typing import ( @@ -603,18 +604,44 @@ class RootConfig: " may specify directories containing *.yaml files.", ) - generate_group = parser.add_argument_group("Config generation") - generate_group.add_argument( + # we nest the mutually-exclusive group inside another group so that the help + # text shows them in their own group. + generate_mode_group = parser.add_argument_group( + "Config generation mode", + ) + generate_mode_exclusive = generate_mode_group.add_mutually_exclusive_group() + generate_mode_exclusive.add_argument( + # hidden option to make the type and default work + "--generate-mode", + help=argparse.SUPPRESS, + type=_ConfigGenerateMode, + default=_ConfigGenerateMode.GENERATE_MISSING_AND_RUN, + ) + generate_mode_exclusive.add_argument( "--generate-config", - action="store_true", help="Generate a config file, then exit.", + action="store_const", + const=_ConfigGenerateMode.GENERATE_EVERYTHING_AND_EXIT, + dest="generate_mode", ) - generate_group.add_argument( + generate_mode_exclusive.add_argument( "--generate-missing-configs", "--generate-keys", - action="store_true", help="Generate any missing additional config files, then exit.", + action="store_const", + const=_ConfigGenerateMode.GENERATE_MISSING_AND_EXIT, + dest="generate_mode", ) + generate_mode_exclusive.add_argument( + "--generate-missing-and-run", + help="Generate any missing additional config files, then run. This is the " + "default behaviour.", + action="store_const", + const=_ConfigGenerateMode.GENERATE_MISSING_AND_RUN, + dest="generate_mode", + ) + + generate_group = parser.add_argument_group("Details for --generate-config") generate_group.add_argument( "-H", "--server-name", help="The server name to generate a config file for." ) @@ -670,11 +697,12 @@ class RootConfig: config_dir_path = os.path.abspath(config_dir_path) data_dir_path = os.getcwd() - generate_missing_configs = config_args.generate_missing_configs - obj = cls(config_files) - if config_args.generate_config: + if ( + config_args.generate_mode + == _ConfigGenerateMode.GENERATE_EVERYTHING_AND_EXIT + ): if config_args.report_stats is None: parser.error( "Please specify either --report-stats=yes or --report-stats=no\n\n" @@ -732,11 +760,14 @@ class RootConfig: ) % (config_path,) ) - generate_missing_configs = True config_dict = read_config_files(config_files) - if generate_missing_configs: - obj.generate_missing_files(config_dict, config_dir_path) + obj.generate_missing_files(config_dict, config_dir_path) + + if config_args.generate_mode in ( + _ConfigGenerateMode.GENERATE_EVERYTHING_AND_EXIT, + _ConfigGenerateMode.GENERATE_MISSING_AND_EXIT, + ): return None obj.parse_config_dict( @@ -965,6 +996,12 @@ def read_file(file_path: Any, config_path: Iterable[str]) -> str: raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e +class _ConfigGenerateMode(Enum): + GENERATE_MISSING_AND_RUN = auto() + GENERATE_MISSING_AND_EXIT = auto() + GENERATE_EVERYTHING_AND_EXIT = auto() + + __all__ = [ "Config", "RootConfig", diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 3b42be5b5b..f3134834e5 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -42,6 +42,35 @@ class MetricsConfig(Config): def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.enable_metrics = config.get("enable_metrics", False) + + """ + ### `enable_legacy_metrics` (experimental) + + **Experimental: this option may be removed or have its behaviour + changed at any time, with no notice.** + + Set to `true` to publish both legacy and non-legacy Prometheus metric names, + or to `false` to only publish non-legacy Prometheus metric names. + Defaults to `true`. Has no effect if `enable_metrics` is `false`. + + Legacy metric names include: + - metrics containing colons in the name, such as `synapse_util_caches_response_cache:hits`, because colons are supposed to be reserved for user-defined recording rules; + - counters that don't end with the `_total` suffix, such as `synapse_federation_client_sent_edus`, therefore not adhering to the OpenMetrics standard. + + These legacy metric names are unconventional and not compliant with OpenMetrics standards. + They are included for backwards compatibility. + + Example configuration: + ```yaml + enable_legacy_metrics: false + ``` + + See https://github.com/matrix-org/synapse/issues/11106 for context. + + *Since v1.67.0.* + """ + self.enable_legacy_metrics = config.get("enable_legacy_metrics", True) + self.report_stats = config.get("report_stats", None) self.report_stats_endpoint = config.get( "report_stats_endpoint", "https://matrix.org/report-usage-stats/push" diff --git a/synapse/config/registration.py b/synapse/config/registration.py index a888d976f2..df1d83dfaa 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse -from typing import Any, Optional +from typing import Any, Dict, Optional from synapse.api.constants import RoomCreationPreset -from synapse.config._base import Config, ConfigError +from synapse.config._base import Config, ConfigError, read_file from synapse.types import JsonDict, RoomAlias, UserID from synapse.util.stringutils import random_string_with_symbols, strtobool @@ -27,6 +27,11 @@ password resets, configure Synapse with an SMTP server via the `email` setting, remove `account_threepid_delegates.email`. """ +CONFLICTING_SHARED_SECRET_OPTS_ERROR = """\ +You have configured both `registration_shared_secret` and +`registration_shared_secret_path`. These are mutually incompatible. +""" + class RegistrationConfig(Config): section = "registration" @@ -53,7 +58,16 @@ class RegistrationConfig(Config): self.enable_registration_token_3pid_bypass = config.get( "enable_registration_token_3pid_bypass", False ) + + # read the shared secret, either inline or from an external file self.registration_shared_secret = config.get("registration_shared_secret") + registration_shared_secret_path = config.get("registration_shared_secret_path") + if registration_shared_secret_path: + if self.registration_shared_secret: + raise ConfigError(CONFLICTING_SHARED_SECRET_OPTS_ERROR) + self.registration_shared_secret = read_file( + registration_shared_secret_path, ("registration_shared_secret_path",) + ).strip() self.bcrypt_rounds = config.get("bcrypt_rounds", 12) @@ -218,6 +232,21 @@ class RegistrationConfig(Config): else: return "" + def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None: + # if 'registration_shared_secret_path' is specified, and the target file + # does not exist, generate it. + registration_shared_secret_path = config.get("registration_shared_secret_path") + if registration_shared_secret_path and not self.path_exists( + registration_shared_secret_path + ): + print( + "Generating registration shared secret file " + + registration_shared_secret_path + ) + secret = random_string_with_symbols(50) + with open(registration_shared_secret_path, "w") as f: + f.write(f"{secret}\n") + @staticmethod def add_arguments(parser: argparse.ArgumentParser) -> None: reg_group = parser.add_argument_group("registration") diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 7520647d1e..23b799ac32 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -28,6 +28,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.events import EventBase from synapse.events.utils import prune_event, prune_event_dict +from synapse.logging.opentracing import trace from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -35,6 +36,7 @@ logger = logging.getLogger(__name__) Hasher = Callable[[bytes], "hashlib._Hash"] +@trace def check_event_content_hash( event: EventBase, hash_algorithm: Hasher = hashlib.sha256 ) -> bool: diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 4a3bfb38f1..623a2c71ea 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -32,6 +32,7 @@ from typing_extensions import Literal import synapse from synapse.api.errors import Codes +from synapse.logging.opentracing import trace from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.media_storage import ReadableFileWrapper from synapse.spam_checker_api import RegistrationBehaviour @@ -378,6 +379,7 @@ class SpamChecker: if check_media_file_for_spam is not None: self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam) + @trace async def check_event_for_spam( self, event: "synapse.events.EventBase" ) -> Union[Tuple[Codes, JsonDict], str]: diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 2522bf78fc..4269a98db2 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -23,6 +23,7 @@ from synapse.crypto.keyring import Keyring from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event, validate_canonicaljson from synapse.http.servlet import assert_params_in_dict +from synapse.logging.opentracing import log_kv, trace from synapse.types import JsonDict, get_domain_from_id if TYPE_CHECKING: @@ -55,6 +56,7 @@ class FederationBase: self._clock = hs.get_clock() self._storage_controllers = hs.get_storage_controllers() + @trace async def _check_sigs_and_hash( self, room_version: RoomVersion, pdu: EventBase ) -> EventBase: @@ -97,17 +99,36 @@ class FederationBase: "Event %s seems to have been redacted; using our redacted copy", pdu.event_id, ) + log_kv( + { + "message": "Event seems to have been redacted; using our redacted copy", + "event_id": pdu.event_id, + } + ) else: logger.warning( "Event %s content has been tampered, redacting", pdu.event_id, ) + log_kv( + { + "message": "Event content has been tampered, redacting", + "event_id": pdu.event_id, + } + ) return redacted_event spam_check = await self.spam_checker.check_event_for_spam(pdu) if spam_check != self.spam_checker.NOT_SPAM: logger.warning("Event contains spam, soft-failing %s", pdu.event_id) + log_kv( + { + "message": "Event contains spam, redacting (to save disk space) " + "as well as soft-failing (to stop using the event in prev_events)", + "event_id": pdu.event_id, + } + ) # we redact (to save disk space) as well as soft-failing (to stop # using the event in prev_events). redacted_event = prune_event(pdu) @@ -117,6 +138,7 @@ class FederationBase: return pdu +@trace async def _check_sigs_on_pdu( keyring: Keyring, room_version: RoomVersion, pdu: EventBase ) -> None: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 987f6dad46..7ee2974bb1 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -61,7 +61,7 @@ from synapse.federation.federation_base import ( ) from synapse.federation.transport.client import SendJoinResponse from synapse.http.types import QueryParams -from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace +from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache @@ -587,11 +587,15 @@ class FederationClient(FederationBase): Returns: A list of PDUs that have valid signatures and hashes. """ + set_tag( + SynapseTags.RESULT_PREFIX + "pdus.length", + str(len(pdus)), + ) # We limit how many PDUs we check at once, as if we try to do hundreds # of thousands of PDUs at once we see large memory spikes. - valid_pdus = [] + valid_pdus: List[EventBase] = [] async def _execute(pdu: EventBase) -> None: valid_pdu = await self._check_sigs_and_hash_and_fetch_one( @@ -607,6 +611,8 @@ class FederationClient(FederationBase): return valid_pdus + @trace + @tag_args async def _check_sigs_and_hash_and_fetch_one( self, pdu: EventBase, @@ -639,16 +645,27 @@ class FederationClient(FederationBase): except InvalidEventSignatureError as e: logger.warning( "Signature on retrieved event %s was invalid (%s). " - "Checking local store/orgin server", + "Checking local store/origin server", pdu.event_id, e, ) + log_kv( + { + "message": "Signature on retrieved event was invalid. " + "Checking local store/origin server", + "event_id": pdu.event_id, + "InvalidEventSignatureError": e, + } + ) # Check local db. res = await self.store.get_event( pdu.event_id, allow_rejected=True, allow_none=True ) + # If the PDU fails its signature check and we don't have it in our + # database, we then request it from sender's server (if that is not the + # same as `origin`). pdu_origin = get_domain_from_id(pdu.sender) if not res and pdu_origin != origin: try: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 75fbc6073d..3bf84cf625 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -763,6 +763,17 @@ class FederationServer(FederationBase): The partial knock event. """ origin_host, _ = parse_server_name(origin) + + if await self.store.is_partial_state_room(room_id): + # Before we do anything: check if the room is partial-stated. + # Note that at the time this check was added, `on_make_knock_request` would + # block due to https://github.com/matrix-org/synapse/issues/12997. + raise SynapseError( + 404, + "Unable to handle /make_knock right now; this server is not fully joined.", + errcode=Codes.NOT_FOUND, + ) + await self.check_server_matches_acl(origin_host, room_id) room_version = await self.store.get_room_version(room_id) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index f5c586f657..9c2c3a0e68 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -310,6 +310,7 @@ class DeviceHandler(DeviceWorkerHandler): super().__init__(hs) self.federation_sender = hs.get_federation_sender() + self._storage_controllers = hs.get_storage_controllers() self.device_list_updater = DeviceListUpdater(hs, self) @@ -694,8 +695,11 @@ class DeviceHandler(DeviceWorkerHandler): # Ignore any users that aren't ours if self.hs.is_mine_id(user_id): - joined_user_ids = await self.store.get_users_in_room(room_id) - hosts = {get_domain_from_id(u) for u in joined_user_ids} + hosts = set( + await self._storage_controllers.state.get_current_hosts_in_room( + room_id + ) + ) hosts.discard(self.server_name) # Check if we've already sent this update to some hosts diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 948f66a94d..7127d5aefc 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -30,7 +30,7 @@ from synapse.api.errors import ( from synapse.appservice import ApplicationService from synapse.module_api import NOT_SPAM from synapse.storage.databases.main.directory import RoomAliasMapping -from synapse.types import JsonDict, Requester, RoomAlias, get_domain_from_id +from synapse.types import JsonDict, Requester, RoomAlias if TYPE_CHECKING: from synapse.server import HomeServer @@ -83,8 +83,9 @@ class DirectoryHandler: # TODO(erikj): Add transactions. # TODO(erikj): Check if there is a current association. if not servers: - users = await self.store.get_users_in_room(room_id) - servers = {get_domain_from_id(u) for u in users} + servers = await self._storage_controllers.state.get_current_hosts_in_room( + room_id + ) if not servers: raise SynapseError(400, "Failed to get server list") @@ -287,8 +288,9 @@ class DirectoryHandler: Codes.NOT_FOUND, ) - users = await self.store.get_users_in_room(room_id) - extra_servers = {get_domain_from_id(u) for u in users} + extra_servers = await self._storage_controllers.state.get_current_hosts_in_room( + room_id + ) servers_set = set(extra_servers) | set(servers) # If this server is in the list of servers, return it first. diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index a2dd9c7efa..c3ddc5d182 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -129,12 +129,9 @@ class EventAuthHandler: else: users = {} - # Find the user with the highest power level. - users_in_room = await self._store.get_users_in_room(room_id) - # Only interested in local users. - local_users_in_room = [ - u for u in users_in_room if get_domain_from_id(u) == self._server_name - ] + # Find the user with the highest power level (only interested in local + # users). + local_users_in_room = await self._store.get_local_users_in_room(room_id) chosen_user = max( local_users_in_room, key=lambda user: users.get(user, users_default_level), diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index ac13340d3a..949b69cb41 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -151,7 +151,7 @@ class EventHandler: """Retrieve a single specified event. Args: - user: The user requesting the event + user: The local user requesting the event room_id: The expected room id. We'll return None if the event's room does not match. event_id: The event ID to obtain. @@ -173,8 +173,11 @@ class EventHandler: if not event: return None - users = await self.store.get_users_in_room(event.room_id) - is_peeking = user.to_string() not in users + is_user_in_room = await self.store.check_local_user_in_room( + user_id=user.to_string(), room_id=event.room_id + ) + # The user is peeking if they aren't in the room already + is_peeking = not is_user_in_room filtered = await filter_events_for_client( self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index e151962055..dd4b9f66d1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -70,7 +70,7 @@ from synapse.replication.http.federation import ( from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter -from synapse.types import JsonDict, StateMap, get_domain_from_id +from synapse.types import JsonDict, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination from synapse.visibility import filter_events_for_server @@ -104,37 +104,6 @@ backfill_processing_before_timer = Histogram( ) -def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]: - """Get joined domains from state - - Args: - state: State map from type/state key to event. - - Returns: - Returns a list of servers with the lowest depth of their joins. - Sorted by lowest depth first. - """ - joined_users = [ - (state_key, int(event.depth)) - for (e_type, state_key), event in state.items() - if e_type == EventTypes.Member and event.membership == Membership.JOIN - ] - - joined_domains: Dict[str, int] = {} - for u, d in joined_users: - try: - dom = get_domain_from_id(u) - old_d = joined_domains.get(dom) - if old_d: - joined_domains[dom] = min(d, old_d) - else: - joined_domains[dom] = d - except Exception: - pass - - return sorted(joined_domains.items(), key=lambda d: d[1]) - - class _BackfillPointType(Enum): # a regular backwards extremity (ie, an event which we don't yet have, but which # is referred to by other events in the DAG) @@ -432,21 +401,19 @@ class FederationHandler: ) # Now we need to decide which hosts to hit first. - - # First we try hosts that are already in the room + # First we try hosts that are already in the room. # TODO: HEURISTIC ALERT. + likely_domains = ( + await self._storage_controllers.state.get_current_hosts_in_room(room_id) + ) - curr_state = await self._storage_controllers.state.get_current_state(room_id) - - curr_domains = get_domains_from_state(curr_state) - - likely_domains = [ - domain for domain, depth in curr_domains if domain != self.server_name - ] - - async def try_backfill(domains: List[str]) -> bool: + async def try_backfill(domains: Collection[str]) -> bool: # TODO: Should we try multiple of these at a time? for dom in domains: + # We don't want to ask our own server for information we don't have + if dom == self.server_name: + continue + try: await self._federation_event_handler.backfill( dom, room_id, limit=100, extremities=extremities_to_request diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 048c4111f6..ace7adcffb 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1041,6 +1041,14 @@ class FederationEventHandler: InvalidResponseError: if the remote homeserver's response contains fields of the wrong type. """ + + # It would be better if we could query the difference from our known + # state to the given `event_id` so the sending server doesn't have to + # send as much and we don't have to process as many events. For example + # in a room like #matrix:matrix.org, we get 200k events (77k state_events, 122k + # auth_events) from this call. + # + # Tracked by https://github.com/matrix-org/synapse/issues/13618 ( state_event_ids, auth_event_ids, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index acd3de06f6..72157d5a36 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -761,8 +761,10 @@ class EventCreationHandler: async def _is_server_notices_room(self, room_id: str) -> bool: if self.config.servernotices.server_notices_mxid is None: return False - user_ids = await self.store.get_users_in_room(room_id) - return self.config.servernotices.server_notices_mxid in user_ids + is_server_notices_room = await self.store.check_local_user_in_room( + user_id=self.config.servernotices.server_notices_mxid, room_id=room_id + ) + return is_server_notices_room async def assert_accepted_privacy_policy(self, requester: Requester) -> None: """Check if a user has accepted the privacy policy diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 74e944bce7..a0c39778ab 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -159,11 +159,9 @@ class PaginationHandler: self._retention_allowed_lifetime_max = ( hs.config.retention.retention_allowed_lifetime_max ) + self._is_master = hs.config.worker.worker_app is None - if ( - hs.config.worker.run_background_tasks - and hs.config.retention.retention_enabled - ): + if hs.config.retention.retention_enabled and self._is_master: # Run the purge jobs described in the configuration file. for job in hs.config.retention.retention_purge_jobs: logger.info("Setting up purge job with config: %s", job) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 741504ba9f..4e575ffbaa 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -2051,8 +2051,7 @@ async def get_interested_remotes( ) for room_id, states in room_ids_to_states.items(): - user_ids = await store.get_users_in_room(room_id) - hosts = {get_domain_from_id(user_id) for user_id in user_ids} + hosts = await store.get_current_hosts_in_room(room_id) for host in hosts: hosts_and_states.setdefault(host, set()).update(states) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2bf0ebd025..f64a8690a5 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -60,7 +60,6 @@ from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.federation.federation_client import InvalidResponseError -from synapse.handlers.federation import get_domains_from_state from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM from synapse.rest.admin._base import assert_user_is_admin @@ -1284,8 +1283,11 @@ class RoomContextHandler: before_limit = math.floor(limit / 2.0) after_limit = limit - before_limit - users = await self.store.get_users_in_room(room_id) - is_peeking = user.to_string() not in users + is_user_in_room = await self.store.check_local_user_in_room( + user_id=user.to_string(), room_id=room_id + ) + # The user is peeking if they aren't in the room already + is_peeking = not is_user_in_room async def filter_evts(events: List[EventBase]) -> List[EventBase]: if use_admin_priviledge: @@ -1459,17 +1461,16 @@ class TimestampLookupHandler: timestamp, ) - # Find other homeservers from the given state in the room - curr_state = await self._storage_controllers.state.get_current_state( - room_id + likely_domains = ( + await self._storage_controllers.state.get_current_hosts_in_room(room_id) ) - curr_domains = get_domains_from_state(curr_state) - likely_domains = [ - domain for domain, depth in curr_domains if domain != self.server_name - ] # Loop through each homeserver candidate until we get a succesful response for domain in likely_domains: + # We don't want to ask our own server for information we don't have + if domain == self.server_name: + continue + try: remote_response = await self.federation_client.timestamp_to_event( domain, room_id, timestamp, direction diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 65b9a655d4..e726997d83 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1620,8 +1620,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): async def _is_server_notice_room(self, room_id: str) -> bool: if self._server_notices_mxid is None: return False - user_ids = await self.store.get_users_in_room(room_id) - return self._server_notices_mxid in user_ids + is_server_notices_room = await self.store.check_local_user_in_room( + user_id=self._server_notices_mxid, room_id=room_id + ) + return is_server_notices_room class RoomMemberMasterHandler(RoomMemberHandler): @@ -1923,8 +1925,11 @@ class RoomMemberMasterHandler(RoomMemberHandler): ]: raise SynapseError(400, "User %s in room %s" % (user_id, room_id)) - if membership: - await self.store.forget(user_id, room_id) + # In normal case this call is only required if `membership` is not `None`. + # But: After the last member had left the room, the background update + # `_background_remove_left_rooms` is deleting rows related to this room from + # the table `current_state_events` and `get_current_state_events` is `None`. + await self.store.forget(user_id, room_id) def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[str]: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b4d3f3958c..2d95b1fa24 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -2421,10 +2421,10 @@ class SyncHandler: joined_room.room_id, joined_room.event_pos.stream ) ) - users_in_room = await self.state.get_current_users_in_room( + user_ids_in_room = await self.state.get_current_user_ids_in_room( joined_room.room_id, extrems ) - if user_id in users_in_room: + if user_id in user_ids_in_room: joined_room_ids.add(joined_room.room_id) return frozenset(joined_room_ids) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index bcac3372a2..a4cd8b8f0c 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.replication.tcp.streams import TypingStream from synapse.streams import EventSource -from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id +from synapse.types import JsonDict, Requester, StreamKeyType, UserID from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -362,8 +362,9 @@ class TypingWriterHandler(FollowerTypingHandler): ) return - users = await self.store.get_users_in_room(room_id) - domains = {get_domain_from_id(u) for u in users} + domains = await self._storage_controllers.state.get_current_hosts_in_room( + room_id + ) if self.server_name in domains: logger.info("Got typing update from %s: %r", user_id, content) diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 496fce2ecc..c3d3daf877 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -46,12 +46,12 @@ from twisted.python.threadpool import ThreadPool # This module is imported for its side effects; flake8 needn't warn that it's unused. import synapse.metrics._reactor_metrics # noqa: F401 -from synapse.metrics._exposition import ( +from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager +from synapse.metrics._legacy_exposition import ( MetricsResource, generate_latest, start_http_server, ) -from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager from synapse.metrics._types import Collector from synapse.util import SYNAPSE_VERSION diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_legacy_exposition.py index 353d0a63b6..ff640a49af 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_legacy_exposition.py @@ -80,7 +80,27 @@ def sample_line(line: Sample, name: str) -> str: return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp) +# Mapping from new metric names to legacy metric names. +# We translate these back to their old names when exposing them through our +# legacy vendored exporter. +# Only this legacy exposition module applies these name changes. +LEGACY_METRIC_NAMES = { + "synapse_util_caches_cache_hits": "synapse_util_caches_cache:hits", + "synapse_util_caches_cache_size": "synapse_util_caches_cache:size", + "synapse_util_caches_cache_evicted_size": "synapse_util_caches_cache:evicted_size", + "synapse_util_caches_cache_total": "synapse_util_caches_cache:total", + "synapse_util_caches_response_cache_size": "synapse_util_caches_response_cache:size", + "synapse_util_caches_response_cache_hits": "synapse_util_caches_response_cache:hits", + "synapse_util_caches_response_cache_evicted_size": "synapse_util_caches_response_cache:evicted_size", + "synapse_util_caches_response_cache_total": "synapse_util_caches_response_cache:total", +} + + def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes: + """ + Generate metrics in legacy format. Modern metrics are generated directly + by prometheus-client. + """ # Trigger the cache metrics to be rescraped, which updates the common # metrics but do not produce metrics themselves @@ -94,7 +114,8 @@ def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> byt # No samples, don't bother. continue - mname = metric.name + # Translate to legacy metric name if it has one. + mname = LEGACY_METRIC_NAMES.get(metric.name, metric.name) mnewname = metric.name mtype = metric.type @@ -124,7 +145,7 @@ def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> byt om_samples: Dict[str, List[str]] = {} for s in metric.samples: for suffix in ["_created", "_gsum", "_gcount"]: - if s.name == metric.name + suffix: + if s.name == mname + suffix: # OpenMetrics specific sample, put in a gauge at the end. # (these come from gaugehistograms which don't get renamed, # so no need to faff with mnewname) @@ -140,12 +161,12 @@ def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> byt if emit_help: output.append( "# HELP {}{} {}\n".format( - metric.name, + mname, suffix, metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), ) ) - output.append(f"# TYPE {metric.name}{suffix} gauge\n") + output.append(f"# TYPE {mname}{suffix} gauge\n") output.extend(lines) # Get rid of the weird colon things while we're at it @@ -170,11 +191,12 @@ def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> byt # Get rid of the OpenMetrics specific samples (we should already have # dealt with them above anyway.) for suffix in ["_created", "_gsum", "_gcount"]: - if s.name == metric.name + suffix: + if s.name == mname + suffix: break else: + sample_name = LEGACY_METRIC_NAMES.get(s.name, s.name) output.append( - sample_line(s, s.name.replace(":total", "").replace(":", "_")) + sample_line(s, sample_name.replace(":total", "").replace(":", "_")) ) return "".join(output).encode("utf-8") diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 52ee3f7e58..5e65eaf1e0 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -31,6 +31,5 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): self._push_rules_stream_id_gen.advance(instance_name, token) for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) - self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) self.push_rules_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index c35d42fab8..d30878f704 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -254,30 +254,32 @@ async def respond_with_responder( file_size: Size in bytes of the media. If not known it should be None upload_name: The name of the requested file, if any. """ - if request._disconnected: - logger.warning( - "Not sending response to request %s, already disconnected.", request - ) - return - if not responder: respond_404(request) return - logger.debug("Responding to media request with responder %s", responder) - add_file_headers(request, media_type, file_size, upload_name) - try: - with responder: + # If we have a responder we *must* use it as a context manager. + with responder: + if request._disconnected: + logger.warning( + "Not sending response to request %s, already disconnected.", request + ) + return + + logger.debug("Responding to media request with responder %s", responder) + add_file_headers(request, media_type, file_size, upload_name) + try: + await responder.write_to_consumer(request) - except Exception as e: - # The majority of the time this will be due to the client having gone - # away. Unfortunately, Twisted simply throws a generic exception at us - # in that case. - logger.warning("Failed to write to consumer: %s %s", type(e), e) - - # Unregister the producer, if it has one, so Twisted doesn't complain - if request.producer: - request.unregisterProducer() + except Exception as e: + # The majority of the time this will be due to the client having gone + # away. Unfortunately, Twisted simply throws a generic exception at us + # in that case. + logger.warning("Failed to write to consumer: %s %s", type(e), e) + + # Unregister the producer, if it has one, so Twisted doesn't complain + if request.producer: + request.unregisterProducer() finish_request(request) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index b36c98a08e..a8f6fd6b35 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -732,10 +732,6 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("Running url preview cache expiry") - if not (await self.store.db_pool.updates.has_completed_background_updates()): - logger.debug("Still running DB updates; skipping url preview cache expiry") - return - def try_remove_parent_dirs(dirs: Iterable[str]) -> None: """Attempt to remove the given chain of parent directories diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 70d054a8f4..564e3705c2 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -102,6 +102,10 @@ class ServerNoticesManager: Returns: The room's ID, or None if no room could be found. """ + # If there is no server notices MXID, then there is no server notices room + if self.server_notices_mxid is None: + return None + rooms = await self._store.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE, Membership.JOIN] ) @@ -111,8 +115,10 @@ class ServerNoticesManager: # be joined. This is kinda deliberate, in that if somebody somehow # manages to invite the system user to a room, that doesn't make it # the server notices room. - user_ids = await self._store.get_users_in_room(room.room_id) - if len(user_ids) <= 2 and self.server_notices_mxid in user_ids: + is_server_notices_room = await self._store.check_local_user_in_room( + user_id=self.server_notices_mxid, room_id=room.room_id + ) + if is_server_notices_room: # we found a room which our user shares with the system notice # user return room.room_id diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index c355e4f98a..3047e1b1ad 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -44,7 +44,6 @@ from synapse.logging.context import ContextResourceUsage from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.roommember import ProfileInfo from synapse.storage.state import StateFilter from synapse.types import StateMap from synapse.util.async_helpers import Linearizer @@ -210,11 +209,11 @@ class StateHandler: ret = await self.resolve_state_groups_for_events(room_id, event_ids) return await ret.get_state(self._state_storage_controller, state_filter) - async def get_current_users_in_room( + async def get_current_user_ids_in_room( self, room_id: str, latest_event_ids: List[str] - ) -> Dict[str, ProfileInfo]: + ) -> Set[str]: """ - Get the users who are currently in a room. + Get the users IDs who are currently in a room. Note: This is much slower than using the equivalent method `DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`, @@ -225,15 +224,15 @@ class StateHandler: room_id: The ID of the room. latest_event_ids: Precomputed list of latest event IDs. Will be computed if None. Returns: - Dictionary of user IDs to their profileinfo. + Set of user IDs in the room. """ assert latest_event_ids is not None - logger.debug("calling resolve_state_groups from get_current_users_in_room") + logger.debug("calling resolve_state_groups from get_current_user_ids_in_room") entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) state = await entry.get_state(self._state_storage_controller, StateFilter.all()) - return await self.store.get_joined_users_from_state(room_id, state, entry) + return await self.store.get_joined_user_ids_from_state(room_id, state, entry) async def get_hosts_in_room_at_events( self, room_id: str, event_ids: Collection[str] diff --git a/synapse/state/v2.py b/synapse/state/v2.py index cf3045f82e..af03851c71 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -271,40 +271,41 @@ async def _get_power_level_for_sender( async def _get_auth_chain_difference( room_id: str, state_sets: Sequence[Mapping[Any, str]], - event_map: Dict[str, EventBase], + unpersisted_events: Dict[str, EventBase], state_res_store: StateResolutionStore, ) -> Set[str]: """Compare the auth chains of each state set and return the set of events - that only appear in some but not all of the auth chains. + that only appear in some, but not all of the auth chains. Args: - state_sets - event_map - state_res_store + state_sets: The input state sets we are trying to resolve across. + unpersisted_events: A map from event ID to EventBase containing all unpersisted + events involved in this resolution. + state_res_store: Returns: - Set of event IDs + The auth difference of the given state sets, as a set of event IDs. """ # The `StateResolutionStore.get_auth_chain_difference` function assumes that # all events passed to it (and their auth chains) have been persisted - # previously. This is not the case for any events in the `event_map`, and so - # we need to manually handle those events. + # previously. We need to manually handle any other events that are yet to be + # persisted. # - # We do this by: - # 1. calculating the auth chain difference for the state sets based on the - # events in `event_map` alone - # 2. replacing any events in the state_sets that are also in `event_map` - # with their auth events (recursively), and then calling - # `store.get_auth_chain_difference` as normal - # 3. adding the results of 1 and 2 together. - - # Map from event ID in `event_map` to their auth event IDs, and their auth - # event IDs if they appear in the `event_map`. This is the intersection of - # the event's auth chain with the events in the `event_map` *plus* their + # We do this in three steps: + # 1. Compute the set of unpersisted events belonging to the auth difference. + # 2. Replacing any unpersisted events in the state_sets with their auth events, + # recursively, until the state_sets contain only persisted events. + # Then we call `store.get_auth_chain_difference` as normal, which computes + # the set of persisted events belonging to the auth difference. + # 3. Adding the results of 1 and 2 together. + + # Map from event ID in `unpersisted_events` to their auth event IDs, and their auth + # event IDs if they appear in the `unpersisted_events`. This is the intersection of + # the event's auth chain with the events in `unpersisted_events` *plus* their # auth event IDs. events_to_auth_chain: Dict[str, Set[str]] = {} - for event in event_map.values(): + for event in unpersisted_events.values(): chain = {event.event_id} events_to_auth_chain[event.event_id] = chain @@ -312,16 +313,16 @@ async def _get_auth_chain_difference( while to_search: for auth_id in to_search.pop().auth_event_ids(): chain.add(auth_id) - auth_event = event_map.get(auth_id) + auth_event = unpersisted_events.get(auth_id) if auth_event: to_search.append(auth_event) - # We now a) calculate the auth chain difference for the unpersisted events - # and b) work out the state sets to pass to the store. + # We now 1) calculate the auth chain difference for the unpersisted events + # and 2) work out the state sets to pass to the store. # - # Note: If the `event_map` is empty (which is the common case), we can do a + # Note: If there are no `unpersisted_events` (which is the common case), we can do a # much simpler calculation. - if event_map: + if unpersisted_events: # The list of state sets to pass to the store, where each state set is a set # of the event ids making up the state. This is similar to `state_sets`, # except that (a) we only have event ids, not the complete @@ -344,14 +345,18 @@ async def _get_auth_chain_difference( for event_id in state_set.values(): event_chain = events_to_auth_chain.get(event_id) if event_chain is not None: - # We have an event in `event_map`. We add all the auth - # events that it references (that aren't also in `event_map`). - set_ids.update(e for e in event_chain if e not in event_map) + # We have an unpersisted event. We add all the auth + # events that it references which are also unpersisted. + set_ids.update( + e for e in event_chain if e not in unpersisted_events + ) # We also add the full chain of unpersisted event IDs # referenced by this state set, so that we can work out the # auth chain difference of the unpersisted events. - unpersisted_ids.update(e for e in event_chain if e in event_map) + unpersisted_ids.update( + e for e in event_chain if e in unpersisted_events + ) else: set_ids.add(event_id) @@ -361,15 +366,15 @@ async def _get_auth_chain_difference( union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) - difference_from_event_map: Collection[str] = union - intersection + auth_difference_unpersisted_part: Collection[str] = union - intersection else: - difference_from_event_map = () + auth_difference_unpersisted_part = () state_sets_ids = [set(state_set.values()) for state_set in state_sets] difference = await state_res_store.get_auth_chain_difference( room_id, state_sets_ids ) - difference.update(difference_from_event_map) + difference.update(auth_difference_unpersisted_part) return difference diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index f9ffd0e29e..ba5380ce3e 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -23,7 +23,6 @@ from typing import ( List, Mapping, Optional, - Set, Tuple, ) @@ -520,7 +519,7 @@ class StateStorageController: ) return state_map.get(key) - async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: + async def get_current_hosts_in_room(self, room_id: str) -> List[str]: """Get current hosts in room based on current state.""" await self._partial_state_room_tracker.await_full_state(room_id) diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 9af9f4f18e..c38b8a9e5a 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -650,9 +650,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) txn, self.get_account_data_for_room, (user_id,) ) self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,)) - self._invalidate_cache_and_stream( - txn, self.get_push_rules_enabled_for_user, (user_id,) - ) # This user might be contained in the ignored_by cache for other users, # so we have to invalidate it all. self._invalidate_all_cache_and_stream(txn, self.ignored_by) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index eabf9c9739..9f410d69de 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -459,6 +459,32 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas return await self.db_pool.runInteraction("get_push_action_users_in_range", f) + def _get_receipts_by_room_txn( + self, txn: LoggingTransaction, user_id: str + ) -> List[Tuple[str, int]]: + receipt_types_clause, args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + ( + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ), + ) + + sql = f""" + SELECT room_id, MAX(stream_ordering) + FROM receipts_linearized + INNER JOIN events USING (room_id, event_id) + WHERE {receipt_types_clause} + AND user_id = ? + GROUP BY room_id + """ + + args.extend((user_id,)) + txn.execute(sql, args) + return cast(List[Tuple[str, int]], txn.fetchall()) + async def get_unread_push_actions_for_user_in_range_for_http( self, user_id: str, @@ -482,106 +508,45 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will have between 0~limit entries. """ - # find rooms that have a read receipt in them and return the next - # push actions - def get_after_receipt( - txn: LoggingTransaction, - ) -> List[Tuple[str, str, int, str, bool]]: - # find rooms that have a read receipt in them and return the next - # push actions - - receipt_types_clause, args = make_in_list_sql_clause( - self.database_engine, - "receipt_type", - ( - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ), - ) - - sql = f""" - SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, - ep.highlight - FROM ( - SELECT room_id, - MAX(stream_ordering) as stream_ordering - FROM events - INNER JOIN receipts_linearized USING (room_id, event_id) - WHERE {receipt_types_clause} AND user_id = ? - GROUP BY room_id - ) AS rl, - event_push_actions AS ep - WHERE - ep.room_id = rl.room_id - AND ep.stream_ordering > rl.stream_ordering - AND ep.user_id = ? - AND ep.stream_ordering > ? - AND ep.stream_ordering <= ? - AND ep.notif = 1 - ORDER BY ep.stream_ordering ASC LIMIT ? - """ - args.extend( - (user_id, user_id, min_stream_ordering, max_stream_ordering, limit) - ) - txn.execute(sql, args) - return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) - - after_read_receipt = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt + receipts_by_room = dict( + await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_http_receipts", + self._get_receipts_by_room_txn, + user_id=user_id, + ), ) - # There are rooms with push actions in them but you don't have a read receipt in - # them e.g. rooms you've been invited to, so get push actions for rooms which do - # not have read receipts in them too. - def get_no_receipt( + def get_push_actions_txn( txn: LoggingTransaction, ) -> List[Tuple[str, str, int, str, bool]]: - receipt_types_clause, args = make_in_list_sql_clause( - self.database_engine, - "receipt_type", - ( - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ), - ) - - sql = f""" - SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, - ep.highlight + sql = """ + SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight FROM event_push_actions AS ep - INNER JOIN events AS e USING (room_id, event_id) WHERE - ep.room_id NOT IN ( - SELECT room_id FROM receipts_linearized - WHERE {receipt_types_clause} AND user_id = ? - GROUP BY room_id - ) - AND ep.user_id = ? + ep.user_id = ? AND ep.stream_ordering > ? AND ep.stream_ordering <= ? AND ep.notif = 1 ORDER BY ep.stream_ordering ASC LIMIT ? """ - args.extend( - (user_id, user_id, min_stream_ordering, max_stream_ordering, limit) - ) - txn.execute(sql, args) + txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit)) return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) - no_read_receipt = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt + push_actions = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn ) notifs = [ HttpPushAction( - event_id=row[0], - room_id=row[1], - stream_ordering=row[2], - actions=_deserialize_action(row[3], row[4]), + event_id=event_id, + room_id=room_id, + stream_ordering=stream_ordering, + actions=_deserialize_action(actions, highlight), ) - for row in after_read_receipt + no_read_receipt + for event_id, room_id, stream_ordering, actions, highlight in push_actions + # Only include push actions with a stream ordering after any receipt, or without any + # receipt present (invited to but never read rooms). + if stream_ordering > receipts_by_room.get(room_id, 0) ] # Now sort it so it's ordered correctly, since currently it will @@ -617,106 +582,49 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will have between 0~limit entries. """ - # find rooms that have a read receipt in them and return the most recent - # push actions - def get_after_receipt( - txn: LoggingTransaction, - ) -> List[Tuple[str, str, int, str, bool, int]]: - receipt_types_clause, args = make_in_list_sql_clause( - self.database_engine, - "receipt_type", - ( - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ), - ) - - sql = f""" - SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, - ep.highlight, e.received_ts - FROM ( - SELECT room_id, - MAX(stream_ordering) as stream_ordering - FROM events - INNER JOIN receipts_linearized USING (room_id, event_id) - WHERE {receipt_types_clause} AND user_id = ? - GROUP BY room_id - ) AS rl, - event_push_actions AS ep - INNER JOIN events AS e USING (room_id, event_id) - WHERE - ep.room_id = rl.room_id - AND ep.stream_ordering > rl.stream_ordering - AND ep.user_id = ? - AND ep.stream_ordering > ? - AND ep.stream_ordering <= ? - AND ep.notif = 1 - ORDER BY ep.stream_ordering DESC LIMIT ? - """ - args.extend( - (user_id, user_id, min_stream_ordering, max_stream_ordering, limit) - ) - txn.execute(sql, args) - return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) - - after_read_receipt = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt + receipts_by_room = dict( + await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_email_receipts", + self._get_receipts_by_room_txn, + user_id=user_id, + ), ) - # There are rooms with push actions in them but you don't have a read receipt in - # them e.g. rooms you've been invited to, so get push actions for rooms which do - # not have read receipts in them too. - def get_no_receipt( + def get_push_actions_txn( txn: LoggingTransaction, ) -> List[Tuple[str, str, int, str, bool, int]]: - receipt_types_clause, args = make_in_list_sql_clause( - self.database_engine, - "receipt_type", - ( - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ), - ) - - sql = f""" + sql = """ SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight, e.received_ts FROM event_push_actions AS ep INNER JOIN events AS e USING (room_id, event_id) WHERE - ep.room_id NOT IN ( - SELECT room_id FROM receipts_linearized - WHERE {receipt_types_clause} AND user_id = ? - GROUP BY room_id - ) - AND ep.user_id = ? + ep.user_id = ? AND ep.stream_ordering > ? AND ep.stream_ordering <= ? AND ep.notif = 1 ORDER BY ep.stream_ordering DESC LIMIT ? """ - args.extend( - (user_id, user_id, min_stream_ordering, max_stream_ordering, limit) - ) - txn.execute(sql, args) + txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit)) return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) - no_read_receipt = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt + push_actions = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn ) # Make a list of dicts from the two sets of results. notifs = [ EmailPushAction( - event_id=row[0], - room_id=row[1], - stream_ordering=row[2], - actions=_deserialize_action(row[3], row[4]), - received_ts=row[5], + event_id=event_id, + room_id=room_id, + stream_ordering=stream_ordering, + actions=_deserialize_action(actions, highlight), + received_ts=received_ts, ) - for row in after_read_receipt + no_read_receipt + for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions + # Only include push actions with a stream ordering after any receipt, or without any + # receipt present (invited to but never read rooms). + if stream_ordering > receipts_by_room.get(room_id, 0) ] # Now sort it so it's ordered correctly, since currently it will @@ -792,26 +700,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas int(count_as_unread), # unread column ) - def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None: - # We don't use simple_insert_many here to avoid the overhead - # of generating lists of dicts. - - sql = """ - INSERT INTO event_push_actions_staging - (event_id, user_id, actions, notif, highlight, unread) - VALUES (?, ?, ?, ?, ?, ?) - """ - - txn.execute_batch( - sql, - ( - _gen_entry(user_id, actions) - for user_id, actions in user_id_actions.items() - ), - ) - - return await self.db_pool.runInteraction( - "add_push_actions_to_staging", _add_push_actions_to_staging_txn + await self.db_pool.simple_insert_many( + "event_push_actions_staging", + keys=("event_id", "user_id", "actions", "notif", "highlight", "unread"), + values=[ + _gen_entry(user_id, actions) + for user_id, actions in user_id_actions.items() + ], + desc="add_push_actions_to_staging", ) async def remove_push_actions_from_staging(self, event_id: str) -> None: diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 255620f996..5079edd1e0 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -165,7 +165,6 @@ class PushRulesWorkerStore( return _load_rules(rows, enabled_map, self.hs.config.experimental) - @cached(max_entries=5000) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( table="push_rules_enable", @@ -229,9 +228,6 @@ class PushRulesWorkerStore( return results - @cachedList( - cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids" - ) async def bulk_get_push_rules_enabled( self, user_ids: Collection[str] ) -> Dict[str, Dict[str, bool]]: @@ -246,6 +242,7 @@ class PushRulesWorkerStore( iterable=user_ids, retcols=("user_name", "rule_id", "enabled"), desc="bulk_get_push_rules_enabled", + batch_size=1000, ) for row in rows: enabled = bool(row["enabled"]) @@ -792,7 +789,6 @@ class PushRuleStore(PushRulesWorkerStore): self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values) txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) - txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) txn.call_after( self.push_rules_stream_cache.entity_has_changed, user_id, stream_id ) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 827c1f1efd..06500457bd 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -187,27 +187,48 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(max_entries=100000, iterable=True) async def get_users_in_room(self, room_id: str) -> List[str]: + """ + Returns a list of users in the room sorted by longest in the room first + (aka. with the lowest depth). This is done to match the sort in + `get_current_hosts_in_room()` and so we can re-use the cache but it's + not horrible to have here either. + """ + return await self.db_pool.runInteraction( "get_users_in_room", self.get_users_in_room_txn, room_id ) def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]: + """ + Returns a list of users in the room sorted by longest in the room first + (aka. with the lowest depth). This is done to match the sort in + `get_current_hosts_in_room()` and so we can re-use the cache but it's + not horrible to have here either. + """ # If we can assume current_state_events.membership is up to date # then we can avoid a join, which is a Very Good Thing given how # frequently this function gets called. if self._current_state_events_membership_up_to_date: sql = """ - SELECT state_key FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? AND membership = ? + SELECT c.state_key FROM current_state_events as c + /* Get the depth of the event from the events table */ + INNER JOIN events AS e USING (event_id) + WHERE c.type = 'm.room.member' AND c.room_id = ? AND membership = ? + /* Sorted by lowest depth first */ + ORDER BY e.depth ASC; """ else: sql = """ - SELECT state_key FROM room_memberships as m + SELECT c.state_key FROM room_memberships as m + /* Get the depth of the event from the events table */ + INNER JOIN events AS e USING (event_id) INNER JOIN current_state_events as c ON m.event_id = c.event_id AND m.room_id = c.room_id AND m.user_id = c.state_key WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? + /* Sorted by lowest depth first */ + ORDER BY e.depth ASC; """ txn.execute(sql, (room_id, Membership.JOIN)) @@ -534,6 +555,32 @@ class RoomMemberWorkerStore(EventsWorkerStore): desc="get_local_users_in_room", ) + async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool: + """ + Check whether a given local user is currently joined to the given room. + + Returns: + A boolean indicating whether the user is currently joined to the room + + Raises: + Exeption when called with a non-local user to this homeserver + """ + if not self.hs.is_mine_id(user_id): + raise Exception( + "Cannot call 'check_local_user_in_room' on " + "non-local user %s" % (user_id,), + ) + + ( + membership, + member_event_id, + ) = await self.get_local_current_membership_for_user_in_room( + user_id=user_id, + room_id=room_id, + ) + + return membership == Membership.JOIN + async def get_local_current_membership_for_user_in_room( self, user_id: str, room_id: str ) -> Tuple[Optional[str], Optional[str]]: @@ -835,9 +882,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): return shared_room_ids or frozenset() - async def get_joined_users_from_state( + async def get_joined_user_ids_from_state( self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry" - ) -> Dict[str, ProfileInfo]: + ) -> Set[str]: state_group: Union[object, int] = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -848,25 +895,25 @@ class RoomMemberWorkerStore(EventsWorkerStore): assert state_group is not None with Measure(self._clock, "get_joined_users_from_state"): - return await self._get_joined_users_from_context( + return await self._get_joined_user_ids_from_context( room_id, state_group, state, context=state_entry ) @cached(num_args=2, iterable=True, max_entries=100000) - async def _get_joined_users_from_context( + async def _get_joined_user_ids_from_context( self, room_id: str, state_group: Union[object, int], current_state_ids: StateMap[str], event: Optional[EventBase] = None, context: Optional["_StateCacheEntry"] = None, - ) -> Dict[str, ProfileInfo]: + ) -> Set[str]: # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states # with a state_group of None are likely to be different. assert state_group is not None - users_in_room = {} + users_in_room = set() member_event_ids = [ e_id for key, e_id in current_state_ids.items() @@ -879,11 +926,11 @@ class RoomMemberWorkerStore(EventsWorkerStore): # If we do then we can reuse that result and simply update it with # any membership changes in `delta_ids` if context.prev_group and context.delta_ids: - prev_res = self._get_joined_users_from_context.cache.get_immediate( + prev_res = self._get_joined_user_ids_from_context.cache.get_immediate( (room_id, context.prev_group), None ) - if prev_res and isinstance(prev_res, dict): - users_in_room = dict(prev_res) + if prev_res and isinstance(prev_res, set): + users_in_room = prev_res member_event_ids = [ e_id for key, e_id in context.delta_ids.items() @@ -891,7 +938,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ] for etype, state_key in context.delta_ids: if etype == EventTypes.Member: - users_in_room.pop(state_key, None) + users_in_room.discard(state_key) # We check if we have any of the member event ids in the event cache # before we ask the DB @@ -908,71 +955,64 @@ class RoomMemberWorkerStore(EventsWorkerStore): ev_entry = event_map.get(event_id) if ev_entry and not ev_entry.event.rejected_reason: if ev_entry.event.membership == Membership.JOIN: - users_in_room[ev_entry.event.state_key] = ProfileInfo( - display_name=ev_entry.event.content.get("displayname", None), - avatar_url=ev_entry.event.content.get("avatar_url", None), - ) + users_in_room.add(ev_entry.event.state_key) else: missing_member_event_ids.append(event_id) if missing_member_event_ids: - event_to_memberships = await self._get_joined_profiles_from_event_ids( + event_to_memberships = await self._get_user_ids_from_membership_event_ids( missing_member_event_ids ) - users_in_room.update(row for row in event_to_memberships.values() if row) + users_in_room.update( + user_id for user_id in event_to_memberships.values() if user_id + ) if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: if event.event_id in member_event_ids: - users_in_room[event.state_key] = ProfileInfo( - display_name=event.content.get("displayname", None), - avatar_url=event.content.get("avatar_url", None), - ) + users_in_room.add(event.state_key) return users_in_room - @cached(max_entries=10000) - def _get_joined_profile_from_event_id( + @cached( + max_entries=10000, + # This name matches the old function that has been replaced - the cache name + # is kept here to maintain backwards compatibility. + name="_get_joined_profile_from_event_id", + ) + def _get_user_id_from_membership_event_id( self, event_id: str ) -> Optional[Tuple[str, ProfileInfo]]: raise NotImplementedError() @cachedList( - cached_method_name="_get_joined_profile_from_event_id", + cached_method_name="_get_user_id_from_membership_event_id", list_name="event_ids", ) - async def _get_joined_profiles_from_event_ids( + async def _get_user_ids_from_membership_event_ids( self, event_ids: Iterable[str] - ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]: + ) -> Dict[str, Optional[str]]: """For given set of member event_ids check if they point to a join - event and if so return the associated user and profile info. + event. Args: event_ids: The member event IDs to lookup Returns: - Map from event ID to `user_id` and ProfileInfo (or None if not join event). + Map from event ID to `user_id`, or None if event is not a join. """ rows = await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=event_ids, - retcols=("user_id", "display_name", "avatar_url", "event_id"), + retcols=("user_id", "event_id"), keyvalues={"membership": Membership.JOIN}, batch_size=1000, - desc="_get_joined_profiles_from_event_ids", + desc="_get_user_ids_from_membership_event_ids", ) - return { - row["event_id"]: ( - row["user_id"], - ProfileInfo( - avatar_url=row["avatar_url"], display_name=row["display_name"] - ), - ) - for row in rows - } + return {row["event_id"]: row["user_id"] for row in rows} @cached(max_entries=10000) async def is_host_joined(self, room_id: str, host: str) -> bool: @@ -1018,37 +1058,70 @@ class RoomMemberWorkerStore(EventsWorkerStore): return True @cached(iterable=True, max_entries=10000) - async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: - """Get current hosts in room based on current state.""" + async def get_current_hosts_in_room(self, room_id: str) -> List[str]: + """ + Get current hosts in room based on current state. + + The heuristic of sorting by servers who have been in the room the + longest is good because they're most likely to have anything we ask + about. + + Returns: + Returns a list of servers sorted by longest in the room first. (aka. + sorted by join with the lowest depth first). + """ # First we check if we already have `get_users_in_room` in the cache, as # we can just calculate result from that users = self.get_users_in_room.cache.get_immediate( (room_id,), None, update_metrics=False ) - if users is not None: - return {get_domain_from_id(u) for u in users} - - if isinstance(self.database_engine, Sqlite3Engine): + if users is None and isinstance(self.database_engine, Sqlite3Engine): # If we're using SQLite then let's just always use # `get_users_in_room` rather than funky SQL. users = await self.get_users_in_room(room_id) - return {get_domain_from_id(u) for u in users} + + if users is not None: + # Because `users` is sorted from lowest -> highest depth, the list + # of domains will also be sorted that way. + domains: List[str] = [] + # We use a `Set` just for fast lookups + domain_set: Set[str] = set() + for u in users: + domain = get_domain_from_id(u) + if domain not in domain_set: + domain_set.add(domain) + domains.append(domain) + return domains # For PostgreSQL we can use a regex to pull out the domains from the # joined users in `current_state_events` via regex. - def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]: + def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> List[str]: + # Returns a list of servers currently joined in the room sorted by + # longest in the room first (aka. with the lowest depth). The + # heuristic of sorting by servers who have been in the room the + # longest is good because they're most likely to have anything we + # ask about. sql = """ - SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$') - FROM current_state_events + SELECT + /* Match the domain part of the MXID */ + substring(c.state_key FROM '@[^:]*:(.*)$') as server_domain + FROM current_state_events c + /* Get the depth of the event from the events table */ + INNER JOIN events AS e USING (event_id) WHERE - type = 'm.room.member' - AND membership = 'join' - AND room_id = ? + /* Find any join state events in the room */ + c.type = 'm.room.member' + AND c.membership = 'join' + AND c.room_id = ? + /* Group all state events from the same domain into their own buckets (groups) */ + GROUP BY server_domain + /* Sorted by lowest depth first */ + ORDER BY min(e.depth) ASC; """ txn.execute(sql, (room_id,)) - return {d for d, in txn} + return [d for d, in txn] return await self.db_pool.runInteraction( "get_current_hosts_in_room", get_current_hosts_in_room_txn @@ -1131,12 +1204,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): else: # The cache doesn't match the state group or prev state group, # so we calculate the result from first principles. - joined_users = await self.get_joined_users_from_state( + joined_user_ids = await self.get_joined_user_ids_from_state( room_id, state, state_entry ) cache.hosts_to_joined_users = {} - for user_id in joined_users: + for user_id in joined_user_ids: host = intern_string(get_domain_from_id(user_id)) cache.hosts_to_joined_users.setdefault(host, set()).add(user_id) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 3c13859faa..2dfe4c0b66 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -460,8 +460,17 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): # Cast safety: this corresponds to the types returned by the query above. rows.extend(cast(Iterable[Tuple[str, int]], cur)) - # Sort so that we handle rows in order for each instance. - rows.sort() + # Sort by stream_id (ascending, lowest -> highest) so that we handle + # rows in order for each instance because we don't want to overwrite + # the current_position of an instance to a lower stream ID than + # we're actually at. + def sort_by_stream_id_key_func(row: Tuple[str, int]) -> int: + (instance, stream_id) = row + # If `stream_id` is ever `None`, we will see a `TypeError: '<' + # not supported between instances of 'NoneType' and 'X'` error. + return stream_id + + rows.sort(key=sort_by_stream_id_key_func) with self._lock: for ( diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 42f6abb5e1..bdf9b0dc8c 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -34,10 +34,10 @@ TRACK_MEMORY_USAGE = False caches_by_name: Dict[str, Sized] = {} collectors_by_name: Dict[str, "CacheMetric"] = {} -cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) -cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) -cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name", "reason"]) -cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"]) +cache_size = Gauge("synapse_util_caches_cache_size", "", ["name"]) +cache_hits = Gauge("synapse_util_caches_cache_hits", "", ["name"]) +cache_evicted = Gauge("synapse_util_caches_cache_evicted_size", "", ["name", "reason"]) +cache_total = Gauge("synapse_util_caches_cache_total", "", ["name"]) cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"]) cache_memory_usage = Gauge( "synapse_util_caches_cache_size_bytes", @@ -45,12 +45,12 @@ cache_memory_usage = Gauge( ["name"], ) -response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"]) -response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"]) +response_cache_size = Gauge("synapse_util_caches_response_cache_size", "", ["name"]) +response_cache_hits = Gauge("synapse_util_caches_response_cache_hits", "", ["name"]) response_cache_evicted = Gauge( - "synapse_util_caches_response_cache:evicted_size", "", ["name", "reason"] + "synapse_util_caches_response_cache_evicted_size", "", ["name", "reason"] ) -response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"]) +response_cache_total = Gauge("synapse_util_caches_response_cache_total", "", ["name"]) class EvictionReason(Enum): diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 1d6ec22191..6425f851ea 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -14,15 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import enum import threading from typing import ( Callable, + Collection, + Dict, Generic, - Iterable, MutableMapping, Optional, + Set, Sized, + Tuple, TypeVar, Union, cast, @@ -31,7 +35,6 @@ from typing import ( from prometheus_client import Gauge from twisted.internet import defer -from twisted.python import failure from twisted.python.failure import Failure from synapse.util.async_helpers import ObservableDeferred @@ -94,7 +97,7 @@ class DeferredCache(Generic[KT, VT]): # _pending_deferred_cache maps from the key value to a `CacheEntry` object. self._pending_deferred_cache: Union[ - TreeCache, "MutableMapping[KT, CacheEntry]" + TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]" ] = cache_type() def metrics_cb() -> None: @@ -159,15 +162,16 @@ class DeferredCache(Generic[KT, VT]): Raises: KeyError if the key is not found in the cache """ - callbacks = [callback] if callback else [] val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) if val is not _Sentinel.sentinel: - val.callbacks.update(callbacks) + val.add_invalidation_callback(key, callback) if update_metrics: m = self.cache.metrics assert m # we always have a name, so should always have metrics m.inc_hits() - return val.deferred.observe() + return val.deferred(key) + + callbacks = (callback,) if callback else () val2 = self.cache.get( key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics @@ -177,6 +181,73 @@ class DeferredCache(Generic[KT, VT]): else: return defer.succeed(val2) + def get_bulk( + self, + keys: Collection[KT], + callback: Optional[Callable[[], None]] = None, + ) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]: + """Bulk lookup of items in the cache. + + Returns: + A 3-tuple of: + 1. a dict of key/value of items already cached; + 2. a deferred that resolves to a dict of key/value of items + we're already fetching; and + 3. a collection of keys that don't appear in the previous two. + """ + + # The cached results + cached = {} + + # List of pending deferreds + pending = [] + + # Dict that gets filled out when the pending deferreds complete + pending_results = {} + + # List of keys that aren't in either cache + missing = [] + + callbacks = (callback,) if callback else () + + for key in keys: + # Check if its in the main cache. + immediate_value = self.cache.get( + key, + _Sentinel.sentinel, + callbacks=callbacks, + ) + if immediate_value is not _Sentinel.sentinel: + cached[key] = immediate_value + continue + + # Check if its in the pending cache + pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel) + if pending_value is not _Sentinel.sentinel: + pending_value.add_invalidation_callback(key, callback) + + def completed_cb(value: VT, key: KT) -> VT: + pending_results[key] = value + return value + + # Add a callback to fill out `pending_results` when that completes + d = pending_value.deferred(key).addCallback(completed_cb, key) + pending.append(d) + continue + + # Not in either cache + missing.append(key) + + # If we've got pending deferreds, squash them into a single one that + # returns `pending_results`. + pending_deferred = None + if pending: + pending_deferred = defer.gatherResults( + pending, consumeErrors=True + ).addCallback(lambda _: pending_results) + + return (cached, pending_deferred, missing) + def get_immediate( self, key: KT, default: T, update_metrics: bool = True ) -> Union[VT, T]: @@ -218,84 +289,89 @@ class DeferredCache(Generic[KT, VT]): value: a deferred which will complete with a result to add to the cache callback: An optional callback to be called when the entry is invalidated """ - if not isinstance(value, defer.Deferred): - raise TypeError("not a Deferred") - - callbacks = [callback] if callback else [] self.check_thread() - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry: - existing_entry.invalidate() + self._pending_deferred_cache.pop(key, None) # XXX: why don't we invalidate the entry in `self.cache` yet? - # we can save a whole load of effort if the deferred is ready. - if value.called: - result = value.result - if not isinstance(result, failure.Failure): - self.cache.set(key, cast(VT, result), callbacks) - return value - # otherwise, we'll add an entry to the _pending_deferred_cache for now, # and add callbacks to add it to the cache properly later. + entry = CacheEntrySingle[KT, VT](value) + entry.add_invalidation_callback(key, callback) + self._pending_deferred_cache[key] = entry + deferred = entry.deferred(key).addCallbacks( + self._completed_callback, + self._error_callback, + callbackArgs=(entry, key), + errbackArgs=(entry, key), + ) - observable = ObservableDeferred(value, consumeErrors=True) - observer = observable.observe() - entry = CacheEntry(deferred=observable, callbacks=callbacks) + # we return a new Deferred which will be called before any subsequent observers. + return deferred - self._pending_deferred_cache[key] = entry + def start_bulk_input( + self, + keys: Collection[KT], + callback: Optional[Callable[[], None]] = None, + ) -> "CacheMultipleEntries[KT, VT]": + """Bulk set API for use when fetching multiple keys at once from the DB. - def compare_and_pop() -> bool: - """Check if our entry is still the one in _pending_deferred_cache, and - if so, pop it. - - Returns true if the entries matched. - """ - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry is entry: - return True - - # oops, the _pending_deferred_cache has been updated since - # we started our query, so we are out of date. - # - # Better put back whatever we took out. (We do it this way - # round, rather than peeking into the _pending_deferred_cache - # and then removing on a match, to make the common case faster) - if existing_entry is not None: - self._pending_deferred_cache[key] = existing_entry - - return False - - def cb(result: VT) -> None: - if compare_and_pop(): - self.cache.set(key, result, entry.callbacks) - else: - # we're not going to put this entry into the cache, so need - # to make sure that the invalidation callbacks are called. - # That was probably done when _pending_deferred_cache was - # updated, but it's possible that `set` was called without - # `invalidate` being previously called, in which case it may - # not have been. Either way, let's double-check now. - entry.invalidate() - - def eb(_fail: Failure) -> None: - compare_and_pop() - entry.invalidate() - - # once the deferred completes, we can move the entry from the - # _pending_deferred_cache to the real cache. - # - observer.addCallbacks(cb, eb) + Called *before* starting the fetch from the DB, and the caller *must* + call either `complete_bulk(..)` or `error_bulk(..)` on the return value. + """ - # we return a new Deferred which will be called before any subsequent observers. - return observable.observe() + entry = CacheMultipleEntries[KT, VT]() + entry.add_global_invalidation_callback(callback) + + for key in keys: + self._pending_deferred_cache[key] = entry + + return entry + + def _completed_callback( + self, value: VT, entry: "CacheEntry[KT, VT]", key: KT + ) -> VT: + """Called when a deferred is completed.""" + # We check if the current entry matches the entry associated with the + # deferred. If they don't match then it got invalidated. + current_entry = self._pending_deferred_cache.pop(key, None) + if current_entry is not entry: + if current_entry: + self._pending_deferred_cache[key] = current_entry + return value + + self.cache.set(key, value, entry.get_invalidation_callbacks(key)) + + return value + + def _error_callback( + self, + failure: Failure, + entry: "CacheEntry[KT, VT]", + key: KT, + ) -> Failure: + """Called when a deferred errors.""" + + # We check if the current entry matches the entry associated with the + # deferred. If they don't match then it got invalidated. + current_entry = self._pending_deferred_cache.pop(key, None) + if current_entry is not entry: + if current_entry: + self._pending_deferred_cache[key] = current_entry + return failure + + for cb in entry.get_invalidation_callbacks(key): + cb() + + return failure def prefill( self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None ) -> None: - callbacks = [callback] if callback else [] + callbacks = (callback,) if callback else () self.cache.set(key, value, callbacks=callbacks) + self._pending_deferred_cache.pop(key, None) def invalidate(self, key: KT) -> None: """Delete a key, or tree of entries @@ -311,41 +387,129 @@ class DeferredCache(Generic[KT, VT]): self.cache.del_multi(key) # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, which will (a) stop it being returned - # for future queries and (b) stop it being persisted as a proper entry + # _pending_deferred_cache, which will (a) stop it being returned for + # future queries and (b) stop it being persisted as a proper entry # in self.cache. entry = self._pending_deferred_cache.pop(key, None) - - # run the invalidation callbacks now, rather than waiting for the - # deferred to resolve. if entry: # _pending_deferred_cache.pop should either return a CacheEntry, or, in the # case of a TreeCache, a dict of keys to cache entries. Either way calling # iterate_tree_cache_entry on it will do the right thing. for entry in iterate_tree_cache_entry(entry): - entry.invalidate() + for cb in entry.get_invalidation_callbacks(key): + cb() def invalidate_all(self) -> None: self.check_thread() self.cache.clear() - for entry in self._pending_deferred_cache.values(): - entry.invalidate() + for key, entry in self._pending_deferred_cache.items(): + for cb in entry.get_invalidation_callbacks(key): + cb() + self._pending_deferred_cache.clear() -class CacheEntry: - __slots__ = ["deferred", "callbacks", "invalidated"] +class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta): + """Abstract class for entries in `DeferredCache[KT, VT]`""" - def __init__( - self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] - ): - self.deferred = deferred - self.callbacks = set(callbacks) - self.invalidated = False - - def invalidate(self) -> None: - if not self.invalidated: - self.invalidated = True - for callback in self.callbacks: - callback() - self.callbacks.clear() + @abc.abstractmethod + def deferred(self, key: KT) -> "defer.Deferred[VT]": + """Get a deferred that a caller can wait on to get the value at the + given key""" + ... + + @abc.abstractmethod + def add_invalidation_callback( + self, key: KT, callback: Optional[Callable[[], None]] + ) -> None: + """Add an invalidation callback""" + ... + + @abc.abstractmethod + def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]: + """Get all invalidation callbacks""" + ... + + +class CacheEntrySingle(CacheEntry[KT, VT]): + """An implementation of `CacheEntry` wrapping a deferred that results in a + single cache entry. + """ + + __slots__ = ["_deferred", "_callbacks"] + + def __init__(self, deferred: "defer.Deferred[VT]") -> None: + self._deferred = ObservableDeferred(deferred, consumeErrors=True) + self._callbacks: Set[Callable[[], None]] = set() + + def deferred(self, key: KT) -> "defer.Deferred[VT]": + return self._deferred.observe() + + def add_invalidation_callback( + self, key: KT, callback: Optional[Callable[[], None]] + ) -> None: + if callback is None: + return + + self._callbacks.add(callback) + + def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]: + return self._callbacks + + +class CacheMultipleEntries(CacheEntry[KT, VT]): + """Cache entry that is used for bulk lookups and insertions.""" + + __slots__ = ["_deferred", "_callbacks", "_global_callbacks"] + + def __init__(self) -> None: + self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None + self._callbacks: Dict[KT, Set[Callable[[], None]]] = {} + self._global_callbacks: Set[Callable[[], None]] = set() + + def deferred(self, key: KT) -> "defer.Deferred[VT]": + if not self._deferred: + self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) + return self._deferred.observe().addCallback(lambda res: res.get(key)) + + def add_invalidation_callback( + self, key: KT, callback: Optional[Callable[[], None]] + ) -> None: + if callback is None: + return + + self._callbacks.setdefault(key, set()).add(callback) + + def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]: + return self._callbacks.get(key, set()) | self._global_callbacks + + def add_global_invalidation_callback( + self, callback: Optional[Callable[[], None]] + ) -> None: + """Add a callback for when any keys get invalidated.""" + if callback is None: + return + + self._global_callbacks.add(callback) + + def complete_bulk( + self, + cache: DeferredCache[KT, VT], + result: Dict[KT, VT], + ) -> None: + """Called when there is a result""" + for key, value in result.items(): + cache._completed_callback(value, self, key) + + if self._deferred: + self._deferred.callback(result) + + def error_bulk( + self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure + ) -> None: + """Called when bulk lookup failed.""" + for key in keys: + cache._error_callback(failure, self, key) + + if self._deferred: + self._deferred.errback(failure) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 867f315b2a..10aff4d04a 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -25,6 +25,7 @@ from typing import ( Generic, Hashable, Iterable, + List, Mapping, Optional, Sequence, @@ -73,8 +74,10 @@ class _CacheDescriptorBase: num_args: Optional[int], uncached_args: Optional[Collection[str]] = None, cache_context: bool = False, + name: Optional[str] = None, ): self.orig = orig + self.name = name or orig.__name__ arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args @@ -211,7 +214,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: cache: LruCache[CacheKey, Any] = LruCache( - cache_name=self.orig.__name__, + cache_name=self.name, max_size=self.max_entries, ) @@ -241,7 +244,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): wrapped = cast(_CachedFunction, _wrapped) wrapped.cache = cache - obj.__dict__[self.orig.__name__] = wrapped + obj.__dict__[self.name] = wrapped return wrapped @@ -301,12 +304,14 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): cache_context: bool = False, iterable: bool = False, prune_unread_entries: bool = True, + name: Optional[str] = None, ): super().__init__( orig, num_args=num_args, uncached_args=uncached_args, cache_context=cache_context, + name=name, ) if tree and self.num_args < 2: @@ -321,7 +326,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: cache: DeferredCache[CacheKey, Any] = DeferredCache( - name=self.orig.__name__, + name=self.name, max_entries=self.max_entries, tree=self.tree, iterable=self.iterable, @@ -372,7 +377,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): wrapped.cache = cache wrapped.num_args = self.num_args - obj.__dict__[self.orig.__name__] = wrapped + obj.__dict__[self.name] = wrapped return wrapped @@ -393,6 +398,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): cached_method_name: str, list_name: str, num_args: Optional[int] = None, + name: Optional[str] = None, ): """ Args: @@ -403,7 +409,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): but including list_name) to use as cache keys. Defaults to all named args of the function. """ - super().__init__(orig, num_args=num_args, uncached_args=None) + super().__init__(orig, num_args=num_args, uncached_args=None, name=name) self.list_name = list_name @@ -435,16 +441,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] list_args = arg_dict[self.list_name] - results = {} - - def update_results_dict(res: Any, arg: Hashable) -> None: - results[arg] = res - - # list of deferreds to wait for - cached_defers = [] - - missing = set() - # If the cache takes a single arg then that is used as the key, # otherwise a tuple is used. if num_args == 1: @@ -452,6 +448,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): def arg_to_cache_key(arg: Hashable) -> Hashable: return arg + def cache_key_to_arg(key: tuple) -> Hashable: + return key + else: keylist = list(keyargs) @@ -459,58 +458,53 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): keylist[self.list_pos] = arg return tuple(keylist) - for arg in list_args: - try: - res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback) - if not res.called: - res.addCallback(update_results_dict, arg) - cached_defers.append(res) - else: - results[arg] = res.result - except KeyError: - missing.add(arg) + def cache_key_to_arg(key: tuple) -> Hashable: + return key[self.list_pos] + + cache_keys = [arg_to_cache_key(arg) for arg in list_args] + immediate_results, pending_deferred, missing = cache.get_bulk( + cache_keys, callback=invalidate_callback + ) + + results = {cache_key_to_arg(key): v for key, v in immediate_results.items()} + + cached_defers: List["defer.Deferred[Any]"] = [] + if pending_deferred: + + def update_results(r: Dict) -> None: + for k, v in r.items(): + results[cache_key_to_arg(k)] = v + + pending_deferred.addCallback(update_results) + cached_defers.append(pending_deferred) if missing: - # we need a deferred for each entry in the list, - # which we put in the cache. Each deferred resolves with the - # relevant result for that key. - deferreds_map = {} - for arg in missing: - deferred: "defer.Deferred[Any]" = defer.Deferred() - deferreds_map[arg] = deferred - key = arg_to_cache_key(arg) - cached_defers.append( - cache.set(key, deferred, callback=invalidate_callback) - ) + cache_entry = cache.start_bulk_input(missing, invalidate_callback) def complete_all(res: Dict[Hashable, Any]) -> None: - # the wrapped function has completed. It returns a dict. - # We can now update our own result map, and then resolve the - # observable deferreds in the cache. - for e, d1 in deferreds_map.items(): - val = res.get(e, None) - # make sure we update the results map before running the - # deferreds, because as soon as we run the last deferred, the - # gatherResults() below will complete and return the result - # dict to our caller. - results[e] = val - d1.callback(val) + missing_results = {} + for key in missing: + arg = cache_key_to_arg(key) + val = res.get(arg, None) + + results[arg] = val + missing_results[key] = val + + cache_entry.complete_bulk(cache, missing_results) def errback_all(f: Failure) -> None: - # the wrapped function has failed. Propagate the failure into - # the cache, which will invalidate the entry, and cause the - # relevant cached_deferreds to fail, which will propagate the - # failure to our caller. - for d1 in deferreds_map.values(): - d1.errback(f) + cache_entry.error_bulk(cache, missing, f) args_to_call = dict(arg_dict) - args_to_call[self.list_name] = missing + args_to_call[self.list_name] = { + cache_key_to_arg(key) for key in missing + } # dispatch the call, and attach the two handlers - defer.maybeDeferred( + missing_d = defer.maybeDeferred( preserve_fn(self.orig), **args_to_call ).addCallbacks(complete_all, errback_all) + cached_defers.append(missing_d) if cached_defers: d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( @@ -525,7 +519,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): else: return defer.succeed(results) - obj.__dict__[self.orig.__name__] = wrapped + obj.__dict__[self.name] = wrapped return wrapped @@ -577,6 +571,7 @@ def cached( cache_context: bool = False, iterable: bool = False, prune_unread_entries: bool = True, + name: Optional[str] = None, ) -> Callable[[F], _CachedFunction[F]]: func = lambda orig: DeferredCacheDescriptor( orig, @@ -587,13 +582,18 @@ def cached( cache_context=cache_context, iterable=iterable, prune_unread_entries=prune_unread_entries, + name=name, ) return cast(Callable[[F], _CachedFunction[F]], func) def cachedList( - *, cached_method_name: str, list_name: str, num_args: Optional[int] = None + *, + cached_method_name: str, + list_name: str, + num_args: Optional[int] = None, + name: Optional[str] = None, ) -> Callable[[F], _CachedFunction[F]]: """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. @@ -628,6 +628,7 @@ def cachedList( cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, + name=name, ) return cast(Callable[[F], _CachedFunction[F]], func) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index c1b8ec0c73..fec31da2b6 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -135,6 +135,9 @@ class TreeCache: def values(self): return iterate_tree_cache_entry(self.root) + def items(self): + return iterate_tree_cache_items((), self.root) + def __len__(self) -> int: return self.size |