diff options
Diffstat (limited to 'synapse')
111 files changed, 2921 insertions, 1356 deletions
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index dae986c788..4ffe6a1ef3 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -1,5 +1,6 @@ # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2018 New Vector +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,22 +20,23 @@ import hashlib import hmac import logging import sys +from typing import Callable, Optional import requests as _requests import yaml def request_registration( - user, - password, - server_location, - shared_secret, - admin=False, - user_type=None, + user: str, + password: str, + server_location: str, + shared_secret: str, + admin: bool = False, + user_type: Optional[str] = None, requests=_requests, - _print=print, - exit=sys.exit, -): + _print: Callable[[str], None] = print, + exit: Callable[[int], None] = sys.exit, +) -> None: url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),) @@ -65,13 +67,13 @@ def request_registration( mac.update(b"\x00") mac.update(user_type.encode("utf8")) - mac = mac.hexdigest() + hex_mac = mac.hexdigest() data = { "nonce": nonce, "username": user, "password": password, - "mac": mac, + "mac": hex_mac, "admin": admin, "user_type": user_type, } @@ -91,10 +93,17 @@ def request_registration( _print("Success!") -def register_new_user(user, password, server_location, shared_secret, admin, user_type): +def register_new_user( + user: str, + password: str, + server_location: str, + shared_secret: str, + admin: Optional[bool], + user_type: Optional[str], +) -> None: if not user: try: - default_user = getpass.getuser() + default_user: Optional[str] = getpass.getuser() except Exception: default_user = None @@ -123,8 +132,8 @@ def register_new_user(user, password, server_location, shared_secret, admin, use sys.exit(1) if admin is None: - admin = input("Make admin [no]: ") - if admin in ("y", "yes", "true"): + admin_inp = input("Make admin [no]: ") + if admin_inp in ("y", "yes", "true"): admin = True else: admin = False @@ -134,7 +143,7 @@ def register_new_user(user, password, server_location, shared_secret, admin, use ) -def main(): +def main() -> None: logging.captureWarnings(True) diff --git a/synapse/_scripts/review_recent_signups.py b/synapse/_scripts/review_recent_signups.py index 8e66a38421..093af4327a 100644 --- a/synapse/_scripts/review_recent_signups.py +++ b/synapse/_scripts/review_recent_signups.py @@ -92,7 +92,7 @@ def get_recent_users(txn: LoggingTransaction, since_ms: int) -> List[UserInfo]: return user_infos -def main(): +def main() -> None: parser = argparse.ArgumentParser() parser.add_argument( "-c", @@ -142,7 +142,8 @@ def main(): engine = create_engine(database_config.config) with make_conn(database_config, engine, "review_recent_signups") as db_conn: - user_infos = get_recent_users(db_conn.cursor(), since_ms) + # This generates a type of Cursor, not LoggingTransaction. + user_infos = get_recent_users(db_conn.cursor(), since_ms) # type: ignore[arg-type] for user_info in user_infos: if exclude_users_with_email and user_info.emails: diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 4b0a9b2974..13dd6ce248 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -1,7 +1,7 @@ # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd # Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -86,6 +86,9 @@ ROOM_EVENT_FILTER_SCHEMA = { # cf https://github.com/matrix-org/matrix-doc/pull/2326 "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}}, + # MSC3440, filtering by event relations. + "io.element.relation_senders": {"type": "array", "items": {"type": "string"}}, + "io.element.relation_types": {"type": "array", "items": {"type": "string"}}, }, } @@ -146,14 +149,16 @@ def matrix_user_id_validator(user_id_str: str) -> UserID: class Filtering: def __init__(self, hs: "HomeServer"): - super().__init__() + self._hs = hs self.store = hs.get_datastore() + self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {}) + async def get_user_filter( self, user_localpart: str, filter_id: Union[int, str] ) -> "FilterCollection": result = await self.store.get_user_filter(user_localpart, filter_id) - return FilterCollection(result) + return FilterCollection(self._hs, result) def add_user_filter( self, user_localpart: str, user_filter: JsonDict @@ -191,21 +196,22 @@ FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict) class FilterCollection: - def __init__(self, filter_json: JsonDict): + def __init__(self, hs: "HomeServer", filter_json: JsonDict): self._filter_json = filter_json room_filter_json = self._filter_json.get("room", {}) self._room_filter = Filter( - {k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")} + hs, + {k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")}, ) - self._room_timeline_filter = Filter(room_filter_json.get("timeline", {})) - self._room_state_filter = Filter(room_filter_json.get("state", {})) - self._room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {})) - self._room_account_data = Filter(room_filter_json.get("account_data", {})) - self._presence_filter = Filter(filter_json.get("presence", {})) - self._account_data = Filter(filter_json.get("account_data", {})) + self._room_timeline_filter = Filter(hs, room_filter_json.get("timeline", {})) + self._room_state_filter = Filter(hs, room_filter_json.get("state", {})) + self._room_ephemeral_filter = Filter(hs, room_filter_json.get("ephemeral", {})) + self._room_account_data = Filter(hs, room_filter_json.get("account_data", {})) + self._presence_filter = Filter(hs, filter_json.get("presence", {})) + self._account_data = Filter(hs, filter_json.get("account_data", {})) self.include_leave = filter_json.get("room", {}).get("include_leave", False) self.event_fields = filter_json.get("event_fields", []) @@ -232,25 +238,37 @@ class FilterCollection: def include_redundant_members(self) -> bool: return self._room_state_filter.include_redundant_members - def filter_presence( + async def filter_presence( self, events: Iterable[UserPresenceState] ) -> List[UserPresenceState]: - return self._presence_filter.filter(events) + return await self._presence_filter.filter(events) - def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]: - return self._account_data.filter(events) + async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]: + return await self._account_data.filter(events) - def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]: - return self._room_state_filter.filter(self._room_filter.filter(events)) + async def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]: + return await self._room_state_filter.filter( + await self._room_filter.filter(events) + ) - def filter_room_timeline(self, events: Iterable[EventBase]) -> List[EventBase]: - return self._room_timeline_filter.filter(self._room_filter.filter(events)) + async def filter_room_timeline( + self, events: Iterable[EventBase] + ) -> List[EventBase]: + return await self._room_timeline_filter.filter( + await self._room_filter.filter(events) + ) - def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]: - return self._room_ephemeral_filter.filter(self._room_filter.filter(events)) + async def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]: + return await self._room_ephemeral_filter.filter( + await self._room_filter.filter(events) + ) - def filter_room_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]: - return self._room_account_data.filter(self._room_filter.filter(events)) + async def filter_room_account_data( + self, events: Iterable[JsonDict] + ) -> List[JsonDict]: + return await self._room_account_data.filter( + await self._room_filter.filter(events) + ) def blocks_all_presence(self) -> bool: return ( @@ -274,7 +292,9 @@ class FilterCollection: class Filter: - def __init__(self, filter_json: JsonDict): + def __init__(self, hs: "HomeServer", filter_json: JsonDict): + self._hs = hs + self._store = hs.get_datastore() self.filter_json = filter_json self.limit = filter_json.get("limit", 10) @@ -297,6 +317,20 @@ class Filter: self.labels = filter_json.get("org.matrix.labels", None) self.not_labels = filter_json.get("org.matrix.not_labels", []) + # Ideally these would be rejected at the endpoint if they were provided + # and not supported, but that would involve modifying the JSON schema + # based on the homeserver configuration. + if hs.config.experimental.msc3440_enabled: + self.relation_senders = self.filter_json.get( + "io.element.relation_senders", None + ) + self.relation_types = self.filter_json.get( + "io.element.relation_types", None + ) + else: + self.relation_senders = None + self.relation_types = None + def filters_all_types(self) -> bool: return "*" in self.not_types @@ -306,7 +340,7 @@ class Filter: def filters_all_rooms(self) -> bool: return "*" in self.not_rooms - def check(self, event: FilterEvent) -> bool: + def _check(self, event: FilterEvent) -> bool: """Checks whether the filter matches the given event. Args: @@ -420,8 +454,30 @@ class Filter: return room_ids - def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: - return list(filter(self.check, events)) + async def _check_event_relations( + self, events: Iterable[FilterEvent] + ) -> List[FilterEvent]: + # The event IDs to check, mypy doesn't understand the ifinstance check. + event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined] + event_ids_to_keep = set( + await self._store.events_have_relations( + event_ids, self.relation_senders, self.relation_types + ) + ) + + return [ + event + for event in events + if not isinstance(event, EventBase) or event.event_id in event_ids_to_keep + ] + + async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: + result = [event for event in events if self._check(event)] + + if self.relation_senders or self.relation_types: + return await self._check_event_relations(result) + + return result def with_room_ids(self, room_ids: Iterable[str]) -> "Filter": """Returns a new filter with the given room IDs appended. @@ -433,7 +489,7 @@ class Filter: filter: A new filter including the given rooms and the old filter's rooms. """ - newFilter = Filter(self.filter_json) + newFilter = Filter(self._hs, self.filter_json) newFilter.rooms += room_ids return newFilter @@ -444,6 +500,3 @@ def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool: return actual_value.startswith(type_prefix) else: return actual_value == filter_value - - -DEFAULT_FILTER_COLLECTION = FilterCollection({}) diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 4486b3bc7d..f9f9467dc1 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -30,7 +30,8 @@ FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable" STATIC_PREFIX = "/_matrix/static" WEB_CLIENT_PREFIX = "/_matrix/client" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" -MEDIA_PREFIX = "/_matrix/media/r0" +MEDIA_R0_PREFIX = "/_matrix/media/r0" +MEDIA_V3_PREFIX = "/_matrix/media/v3" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index f9940491e8..ee51480a9e 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. import logging import sys +from typing import Container from synapse import python_dependencies # noqa: E402 @@ -27,7 +28,9 @@ except python_dependencies.DependencyException as e: sys.exit(1) -def check_bind_error(e, address, bind_addresses): +def check_bind_error( + e: Exception, address: str, bind_addresses: Container[str] +) -> None: """ This method checks an exception occurred while binding on 0.0.0.0. If :: is specified in the bind addresses a warning is shown. @@ -38,9 +41,9 @@ def check_bind_error(e, address, bind_addresses): When binding on 0.0.0.0 after :: this can safely be ignored. Args: - e (Exception): Exception that was caught. - address (str): Address on which binding was attempted. - bind_addresses (list): Addresses on which the service listens. + e: Exception that was caught. + address: Address on which binding was attempted. + bind_addresses: Addresses on which the service listens. """ if address == "0.0.0.0" and "::" in bind_addresses: logger.warning( diff --git a/synapse/app/_base.py b/synapse/app/_base.py index f2c1028b5d..807ee3d46e 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -22,13 +22,27 @@ import socket import sys import traceback import warnings -from typing import TYPE_CHECKING, Awaitable, Callable, Iterable +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Collection, + Dict, + Iterable, + List, + NoReturn, + Tuple, + cast, +) from cryptography.utils import CryptographyDeprecationWarning -from typing_extensions import NoReturn import twisted -from twisted.internet import defer, error, reactor +from twisted.internet import defer, error, reactor as _reactor +from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorSSL, IReactorTCP +from twisted.internet.protocol import ServerFactory +from twisted.internet.tcp import Port from twisted.logger import LoggingFile, LogLevel from twisted.protocols.tls import TLSMemoryBIOFactory from twisted.python.threadpool import ThreadPool @@ -48,6 +62,7 @@ from synapse.logging.context import PreserveLoggingContext from synapse.metrics import register_threadpool from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.jemalloc import setup_jemalloc_stats +from synapse.types import ISynapseReactor from synapse.util.caches.lrucache import setup_expire_lru_cache_entries from synapse.util.daemonize import daemonize_process from synapse.util.gai_resolver import GAIResolver @@ -57,33 +72,44 @@ from synapse.util.versionstring import get_version_string if TYPE_CHECKING: from synapse.server import HomeServer +# Twisted injects the global reactor to make it easier to import, this confuses +# mypy which thinks it is a module. Tell it that it a more proper type. +reactor = cast(ISynapseReactor, _reactor) + + logger = logging.getLogger(__name__) # list of tuples of function, args list, kwargs dict -_sighup_callbacks = [] +_sighup_callbacks: List[ + Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]] +] = [] -def register_sighup(func, *args, **kwargs): +def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> None: """ Register a function to be called when a SIGHUP occurs. Args: - func (function): Function to be called when sent a SIGHUP signal. + func: Function to be called when sent a SIGHUP signal. *args, **kwargs: args and kwargs to be passed to the target function. """ _sighup_callbacks.append((func, args, kwargs)) -def start_worker_reactor(appname, config, run_command=reactor.run): +def start_worker_reactor( + appname: str, + config: HomeServerConfig, + run_command: Callable[[], None] = reactor.run, +) -> None: """Run the reactor in the main process Daemonizes if necessary, and then configures some resources, before starting the reactor. Pulls configuration from the 'worker' settings in 'config'. Args: - appname (str): application name which will be sent to syslog - config (synapse.config.Config): config object - run_command (Callable[]): callable that actually runs the reactor + appname: application name which will be sent to syslog + config: config object + run_command: callable that actually runs the reactor """ logger = logging.getLogger(config.worker.worker_app) @@ -101,32 +127,32 @@ def start_worker_reactor(appname, config, run_command=reactor.run): def start_reactor( - appname, - soft_file_limit, - gc_thresholds, - pid_file, - daemonize, - print_pidfile, - logger, - run_command=reactor.run, -): + appname: str, + soft_file_limit: int, + gc_thresholds: Tuple[int, int, int], + pid_file: str, + daemonize: bool, + print_pidfile: bool, + logger: logging.Logger, + run_command: Callable[[], None] = reactor.run, +) -> None: """Run the reactor in the main process Daemonizes if necessary, and then configures some resources, before starting the reactor Args: - appname (str): application name which will be sent to syslog - soft_file_limit (int): + appname: application name which will be sent to syslog + soft_file_limit: gc_thresholds: - pid_file (str): name of pid file to write to if daemonize is True - daemonize (bool): true to run the reactor in a background process - print_pidfile (bool): whether to print the pid file, if daemonize is True - logger (logging.Logger): logger instance to pass to Daemonize - run_command (Callable[]): callable that actually runs the reactor + pid_file: name of pid file to write to if daemonize is True + daemonize: true to run the reactor in a background process + print_pidfile: whether to print the pid file, if daemonize is True + logger: logger instance to pass to Daemonize + run_command: callable that actually runs the reactor """ - def run(): + def run() -> None: logger.info("Running") setup_jemalloc_stats() change_resource_limit(soft_file_limit) @@ -185,7 +211,7 @@ def redirect_stdio_to_logs() -> None: print("Redirected stdout/stderr to logs") -def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None: +def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> None: """Register a callback with the reactor, to be called once it is running This can be used to initialise parts of the system which require an asynchronous @@ -195,7 +221,7 @@ def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None: will exit. """ - async def wrapper(): + async def wrapper() -> None: try: await cb(*args, **kwargs) except Exception: @@ -224,7 +250,7 @@ def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None: reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper())) -def listen_metrics(bind_addresses, port): +def listen_metrics(bind_addresses: Iterable[str], port: int) -> None: """ Start Prometheus metrics server. """ @@ -236,11 +262,11 @@ def listen_metrics(bind_addresses, port): def listen_manhole( - bind_addresses: Iterable[str], + bind_addresses: Collection[str], port: int, manhole_settings: ManholeConfig, manhole_globals: dict, -): +) -> None: # twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing # warning. It's fixed by https://github.com/twisted/twisted/pull/1522), so # suppress the warning for now. @@ -259,12 +285,18 @@ def listen_manhole( ) -def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50): +def listen_tcp( + bind_addresses: Collection[str], + port: int, + factory: ServerFactory, + reactor: IReactorTCP = reactor, + backlog: int = 50, +) -> List[Port]: """ Create a TCP socket for a port and several addresses Returns: - list[twisted.internet.tcp.Port]: listening for TCP connections + list of twisted.internet.tcp.Port listening for TCP connections """ r = [] for address in bind_addresses: @@ -273,12 +305,19 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50): except error.CannotListenError as e: check_bind_error(e, address, bind_addresses) - return r + # IReactorTCP returns an object implementing IListeningPort from listenTCP, + # but we know it will be a Port instance. + return r # type: ignore[return-value] def listen_ssl( - bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50 -): + bind_addresses: Collection[str], + port: int, + factory: ServerFactory, + context_factory: IOpenSSLContextFactory, + reactor: IReactorSSL = reactor, + backlog: int = 50, +) -> List[Port]: """ Create an TLS-over-TCP socket for a port and several addresses @@ -294,10 +333,13 @@ def listen_ssl( except error.CannotListenError as e: check_bind_error(e, address, bind_addresses) - return r + # IReactorSSL incorrectly declares that an int is returned from listenSSL, + # it actually returns an object implementing IListeningPort, but we know it + # will be a Port instance. + return r # type: ignore[return-value] -def refresh_certificate(hs: "HomeServer"): +def refresh_certificate(hs: "HomeServer") -> None: """ Refresh the TLS certificates that Synapse is using by re-reading them from disk and updating the TLS context factories to use them. @@ -329,7 +371,7 @@ def refresh_certificate(hs: "HomeServer"): logger.info("Context factories updated.") -async def start(hs: "HomeServer"): +async def start(hs: "HomeServer") -> None: """ Start a Synapse server or worker. @@ -360,7 +402,7 @@ async def start(hs: "HomeServer"): if hasattr(signal, "SIGHUP"): @wrap_as_background_process("sighup") - def handle_sighup(*args, **kwargs): + async def handle_sighup(*args: Any, **kwargs: Any) -> None: # Tell systemd our state, if we're using it. This will silently fail if # we're not using systemd. sdnotify(b"RELOADING=1") @@ -373,7 +415,7 @@ async def start(hs: "HomeServer"): # We defer running the sighup handlers until next reactor tick. This # is so that we're in a sane state, e.g. flushing the logs may fail # if the sighup happens in the middle of writing a log entry. - def run_sighup(*args, **kwargs): + def run_sighup(*args: Any, **kwargs: Any) -> None: # `callFromThread` should be "signal safe" as well as thread # safe. reactor.callFromThread(handle_sighup, *args, **kwargs) @@ -436,12 +478,8 @@ async def start(hs: "HomeServer"): atexit.register(gc.freeze) -def setup_sentry(hs: "HomeServer"): - """Enable sentry integration, if enabled in configuration - - Args: - hs - """ +def setup_sentry(hs: "HomeServer") -> None: + """Enable sentry integration, if enabled in configuration""" if not hs.config.metrics.sentry_enabled: return @@ -466,7 +504,7 @@ def setup_sentry(hs: "HomeServer"): scope.set_tag("worker_name", name) -def setup_sdnotify(hs: "HomeServer"): +def setup_sdnotify(hs: "HomeServer") -> None: """Adds process state hooks to tell systemd what we are up to.""" # Tell systemd our state, if we're using it. This will silently fail if @@ -481,7 +519,7 @@ def setup_sdnotify(hs: "HomeServer"): sdnotify_sockaddr = os.getenv("NOTIFY_SOCKET") -def sdnotify(state): +def sdnotify(state: bytes) -> None: """ Send a notification to systemd, if the NOTIFY_SOCKET env var is set. @@ -490,7 +528,7 @@ def sdnotify(state): package which many OSes don't include as a matter of principle. Args: - state (bytes): notification to send + state: notification to send """ if not isinstance(state, bytes): raise TypeError("sdnotify should be called with a bytes") diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index ad20b1d6aa..42238f7f28 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -17,6 +17,7 @@ import logging import os import sys import tempfile +from typing import List, Optional from twisted.internet import defer, task @@ -25,6 +26,7 @@ from synapse.app import _base from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging +from synapse.events import EventBase from synapse.handlers.admin import ExfiltrationWriter from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore @@ -40,6 +42,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.server import HomeServer from synapse.storage.databases.main.room import RoomWorkerStore +from synapse.types import StateMap from synapse.util.logcontext import LoggingContext from synapse.util.versionstring import get_version_string @@ -65,16 +68,11 @@ class AdminCmdSlavedStore( class AdminCmdServer(HomeServer): - DATASTORE_CLASS = AdminCmdSlavedStore + DATASTORE_CLASS = AdminCmdSlavedStore # type: ignore -async def export_data_command(hs: HomeServer, args): - """Export data for a user. - - Args: - hs - args (argparse.Namespace) - """ +async def export_data_command(hs: HomeServer, args: argparse.Namespace) -> None: + """Export data for a user.""" user_id = args.user_id directory = args.output_directory @@ -92,12 +90,12 @@ class FileExfiltrationWriter(ExfiltrationWriter): Note: This writes to disk on the main reactor thread. Args: - user_id (str): The user whose data is being exfiltrated. - directory (str|None): The directory to write the data to, if None then - will write to a temporary directory. + user_id: The user whose data is being exfiltrated. + directory: The directory to write the data to, if None then will write + to a temporary directory. """ - def __init__(self, user_id, directory=None): + def __init__(self, user_id: str, directory: Optional[str] = None): self.user_id = user_id if directory: @@ -111,7 +109,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): if list(os.listdir(self.base_directory)): raise Exception("Directory must be empty") - def write_events(self, room_id, events): + def write_events(self, room_id: str, events: List[EventBase]) -> None: room_directory = os.path.join(self.base_directory, "rooms", room_id) os.makedirs(room_directory, exist_ok=True) events_file = os.path.join(room_directory, "events") @@ -120,7 +118,9 @@ class FileExfiltrationWriter(ExfiltrationWriter): for event in events: print(json.dumps(event.get_pdu_json()), file=f) - def write_state(self, room_id, event_id, state): + def write_state( + self, room_id: str, event_id: str, state: StateMap[EventBase] + ) -> None: room_directory = os.path.join(self.base_directory, "rooms", room_id) state_directory = os.path.join(room_directory, "state") os.makedirs(state_directory, exist_ok=True) @@ -131,7 +131,9 @@ class FileExfiltrationWriter(ExfiltrationWriter): for event in state.values(): print(json.dumps(event.get_pdu_json()), file=f) - def write_invite(self, room_id, event, state): + def write_invite( + self, room_id: str, event: EventBase, state: StateMap[EventBase] + ) -> None: self.write_events(room_id, [event]) # We write the invite state somewhere else as they aren't full events @@ -145,7 +147,9 @@ class FileExfiltrationWriter(ExfiltrationWriter): for event in state.values(): print(json.dumps(event), file=f) - def write_knock(self, room_id, event, state): + def write_knock( + self, room_id: str, event: EventBase, state: StateMap[EventBase] + ) -> None: self.write_events(room_id, [event]) # We write the knock state somewhere else as they aren't full events @@ -159,11 +163,11 @@ class FileExfiltrationWriter(ExfiltrationWriter): for event in state.values(): print(json.dumps(event), file=f) - def finished(self): + def finished(self) -> str: return self.base_directory -def start(config_options): +def start(config_options: List[str]) -> None: parser = argparse.ArgumentParser(description="Synapse Admin Command") HomeServerConfig.add_arguments_to_parser(parser) @@ -231,7 +235,7 @@ def start(config_options): # We also make sure that `_base.start` gets run before we actually run the # command. - async def run(): + async def run() -> None: with LoggingContext("command"): await _base.start(ss) await args.func(ss, args) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 218826741e..502cc8e8d1 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -14,11 +14,10 @@ # limitations under the License. import logging import sys -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple from twisted.internet import address -from twisted.web.resource import IResource -from twisted.web.server import Request +from twisted.web.resource import Resource import synapse import synapse.events @@ -27,7 +26,8 @@ from synapse.api.urls import ( CLIENT_API_PREFIX, FEDERATION_PREFIX, LEGACY_MEDIA_PREFIX, - MEDIA_PREFIX, + MEDIA_R0_PREFIX, + MEDIA_V3_PREFIX, SERVER_KEY_V2_PREFIX, ) from synapse.app import _base @@ -44,7 +44,7 @@ from synapse.config.server import ListenerConfig from synapse.federation.transport.server import TransportLayerServer from synapse.http.server import JsonResource, OptionsResource from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.http.site import SynapseSite +from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource @@ -119,6 +119,7 @@ from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.transactions import TransactionWorkerStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.user_directory import UserDirectoryStore +from synapse.types import JsonDict from synapse.util.httpresourcetree import create_resource_tree from synapse.util.versionstring import get_version_string @@ -143,7 +144,9 @@ class KeyUploadServlet(RestServlet): self.http_client = hs.get_simple_http_client() self.main_uri = hs.config.worker.worker_main_http_uri - async def on_POST(self, request: Request, device_id: Optional[str]): + async def on_POST( + self, request: SynapseRequest, device_id: Optional[str] + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -187,9 +190,8 @@ class KeyUploadServlet(RestServlet): # If the header exists, add to the comma-separated list of the first # instance of the header. Otherwise, generate a new header. if x_forwarded_for: - x_forwarded_for = [ - x_forwarded_for[0] + b", " + previous_host - ] + x_forwarded_for[1:] + x_forwarded_for = [x_forwarded_for[0] + b", " + previous_host] + x_forwarded_for.extend(x_forwarded_for[1:]) else: x_forwarded_for = [previous_host] headers[b"X-Forwarded-For"] = x_forwarded_for @@ -253,13 +255,16 @@ class GenericWorkerSlavedStore( SessionStore, BaseSlavedStore, ): - pass + # Properties that multiple storage classes define. Tell mypy what the + # expected type is. + server_name: str + config: HomeServerConfig class GenericWorkerServer(HomeServer): - DATASTORE_CLASS = GenericWorkerSlavedStore + DATASTORE_CLASS = GenericWorkerSlavedStore # type: ignore - def _listen_http(self, listener_config: ListenerConfig): + def _listen_http(self, listener_config: ListenerConfig) -> None: port = listener_config.port bind_addresses = listener_config.bind_addresses @@ -267,10 +272,10 @@ class GenericWorkerServer(HomeServer): site_tag = listener_config.http_options.tag if site_tag is None: - site_tag = port + site_tag = str(port) # We always include a health resource. - resources: Dict[str, IResource] = {"/health": HealthResource()} + resources: Dict[str, Resource] = {"/health": HealthResource()} for res in listener_config.http_options.resources: for name in res.names: @@ -334,7 +339,8 @@ class GenericWorkerServer(HomeServer): resources.update( { - MEDIA_PREFIX: media_repo, + MEDIA_R0_PREFIX: media_repo, + MEDIA_V3_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo, "/_synapse/admin": admin_resource, } @@ -386,7 +392,7 @@ class GenericWorkerServer(HomeServer): logger.info("Synapse worker now listening on port %d", port) - def start_listening(self): + def start_listening(self) -> None: for listener in self.config.worker.worker_listeners: if listener.type == "http": self._listen_http(listener) @@ -411,7 +417,7 @@ class GenericWorkerServer(HomeServer): self.get_tcp_replication().start_replication(self) -def start(config_options): +def start(config_options: List[str]) -> None: try: config = HomeServerConfig.load_config("Synapse worker", config_options) except ConfigError as e: diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 336c279a44..7e09530ad2 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -16,10 +16,10 @@ import logging import os import sys -from typing import Iterator +from typing import Dict, Iterable, Iterator, List -from twisted.internet import reactor -from twisted.web.resource import EncodingResourceWrapper, IResource +from twisted.internet.tcp import Port +from twisted.web.resource import EncodingResourceWrapper, Resource from twisted.web.server import GzipEncoderFactory from twisted.web.static import File @@ -29,7 +29,8 @@ from synapse import events from synapse.api.urls import ( FEDERATION_PREFIX, LEGACY_MEDIA_PREFIX, - MEDIA_PREFIX, + MEDIA_R0_PREFIX, + MEDIA_V3_PREFIX, SERVER_KEY_V2_PREFIX, STATIC_PREFIX, WEB_CLIENT_PREFIX, @@ -76,23 +77,27 @@ from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.homeserver") -def gz_wrap(r): +def gz_wrap(r: Resource) -> Resource: return EncodingResourceWrapper(r, [GzipEncoderFactory()]) class SynapseHomeServer(HomeServer): - DATASTORE_CLASS = DataStore + DATASTORE_CLASS = DataStore # type: ignore - def _listener_http(self, config: HomeServerConfig, listener_config: ListenerConfig): + def _listener_http( + self, config: HomeServerConfig, listener_config: ListenerConfig + ) -> Iterable[Port]: port = listener_config.port bind_addresses = listener_config.bind_addresses tls = listener_config.tls + # Must exist since this is an HTTP listener. + assert listener_config.http_options is not None site_tag = listener_config.http_options.tag if site_tag is None: site_tag = str(port) # We always include a health resource. - resources = {"/health": HealthResource()} + resources: Dict[str, Resource] = {"/health": HealthResource()} for res in listener_config.http_options.resources: for name in res.names: @@ -111,7 +116,7 @@ class SynapseHomeServer(HomeServer): ("listeners", site_tag, "additional_resources", "<%s>" % (path,)), ) handler = handler_cls(config, module_api) - if IResource.providedBy(handler): + if isinstance(handler, Resource): resource = handler elif hasattr(handler, "handle_request"): resource = AdditionalResource(self, handler.handle_request) @@ -128,7 +133,7 @@ class SynapseHomeServer(HomeServer): # try to find something useful to redirect '/' to if WEB_CLIENT_PREFIX in resources: - root_resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX) + root_resource: Resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX) elif STATIC_PREFIX in resources: root_resource = RootOptionsRedirectResource(STATIC_PREFIX) else: @@ -145,6 +150,8 @@ class SynapseHomeServer(HomeServer): ) if tls: + # refresh_certificate should have been called before this. + assert self.tls_server_context_factory is not None ports = listen_ssl( bind_addresses, port, @@ -165,20 +172,21 @@ class SynapseHomeServer(HomeServer): return ports - def _configure_named_resource(self, name, compress=False): + def _configure_named_resource( + self, name: str, compress: bool = False + ) -> Dict[str, Resource]: """Build a resource map for a named resource Args: - name (str): named resource: one of "client", "federation", etc - compress (bool): whether to enable gzip compression for this - resource + name: named resource: one of "client", "federation", etc + compress: whether to enable gzip compression for this resource Returns: - dict[str, Resource]: map from path to HTTP resource + map from path to HTTP resource """ - resources = {} + resources: Dict[str, Resource] = {} if name == "client": - client_resource = ClientRestResource(self) + client_resource: Resource = ClientRestResource(self) if compress: client_resource = gz_wrap(client_resource) @@ -186,6 +194,7 @@ class SynapseHomeServer(HomeServer): { "/_matrix/client/api/v1": client_resource, "/_matrix/client/r0": client_resource, + "/_matrix/client/v3": client_resource, "/_matrix/client/unstable": client_resource, "/_matrix/client/v2_alpha": client_resource, "/_matrix/client/versions": client_resource, @@ -207,7 +216,7 @@ class SynapseHomeServer(HomeServer): if name == "consent": from synapse.rest.consent.consent_resource import ConsentResource - consent_resource = ConsentResource(self) + consent_resource: Resource = ConsentResource(self) if compress: consent_resource = gz_wrap(consent_resource) resources.update({"/_matrix/consent": consent_resource}) @@ -237,7 +246,11 @@ class SynapseHomeServer(HomeServer): if self.config.server.enable_media_repo: media_repo = self.get_media_repository_resource() resources.update( - {MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo} + { + MEDIA_R0_PREFIX: media_repo, + MEDIA_V3_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + } ) elif name == "media": raise ConfigError( @@ -277,7 +290,7 @@ class SynapseHomeServer(HomeServer): return resources - def start_listening(self): + def start_listening(self) -> None: if self.config.redis.redis_enabled: # If redis is enabled we connect via the replication command handler # in the same way as the workers (since we're effectively a client @@ -303,7 +316,9 @@ class SynapseHomeServer(HomeServer): ReplicationStreamProtocolFactory(self), ) for s in services: - reactor.addSystemEventTrigger("before", "shutdown", s.stopListening) + self.get_reactor().addSystemEventTrigger( + "before", "shutdown", s.stopListening + ) elif listener.type == "metrics": if not self.config.metrics.enable_metrics: logger.warning( @@ -318,14 +333,13 @@ class SynapseHomeServer(HomeServer): logger.warning("Unrecognized listener type: %s", listener.type) -def setup(config_options): +def setup(config_options: List[str]) -> SynapseHomeServer: """ Args: - config_options_options: The options passed to Synapse. Usually - `sys.argv[1:]`. + config_options_options: The options passed to Synapse. Usually `sys.argv[1:]`. Returns: - HomeServer + A homeserver instance. """ try: config = HomeServerConfig.load_or_generate_config( @@ -364,7 +378,7 @@ def setup(config_options): except Exception as e: handle_startup_exception(e) - async def start(): + async def start() -> None: # Load the OIDC provider metadatas, if OIDC is enabled. if hs.config.oidc.oidc_enabled: oidc = hs.get_oidc_handler() @@ -404,39 +418,15 @@ def format_config_error(e: ConfigError) -> Iterator[str]: yield ":\n %s" % (e.msg,) - e = e.__cause__ + parent_e = e.__cause__ indent = 1 - while e: + while parent_e: indent += 1 - yield ":\n%s%s" % (" " * indent, str(e)) - e = e.__cause__ - - -def run(hs: HomeServer): - PROFILE_SYNAPSE = False - if PROFILE_SYNAPSE: - - def profile(func): - from cProfile import Profile - from threading import current_thread - - def profiled(*args, **kargs): - profile = Profile() - profile.enable() - func(*args, **kargs) - profile.disable() - ident = current_thread().ident - profile.dump_stats( - "/tmp/%s.%s.%i.pstat" % (hs.hostname, func.__name__, ident) - ) - - return profiled - - from twisted.python.threadpool import ThreadPool + yield ":\n%s%s" % (" " * indent, str(parent_e)) + parent_e = parent_e.__cause__ - ThreadPool._worker = profile(ThreadPool._worker) - reactor.run = profile(reactor.run) +def run(hs: HomeServer) -> None: _base.start_reactor( "synapse-homeserver", soft_file_limit=hs.config.server.soft_file_limit, @@ -448,7 +438,7 @@ def run(hs: HomeServer): ) -def main(): +def main() -> None: with LoggingContext("main"): # check base requirements check_requirements() diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 126450e17a..899dba5c3d 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -15,11 +15,12 @@ import logging import math import resource import sys -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List, Sized, Tuple from prometheus_client import Gauge from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -28,7 +29,7 @@ logger = logging.getLogger("synapse.app.homeserver") # Contains the list of processes we will be monitoring # currently either 0 or 1 -_stats_process = [] +_stats_process: List[Tuple[int, "resource.struct_rusage"]] = [] # Gauges to expose monthly active user control metrics current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") @@ -45,9 +46,15 @@ registered_reserved_users_mau_gauge = Gauge( @wrap_as_background_process("phone_stats_home") -async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process): +async def phone_stats_home( + hs: "HomeServer", + stats: JsonDict, + stats_process: List[Tuple[int, "resource.struct_rusage"]] = _stats_process, +) -> None: logger.info("Gathering stats for reporting") now = int(hs.get_clock().time()) + # Ensure the homeserver has started. + assert hs.start_time is not None uptime = int(now - hs.start_time) if uptime < 0: uptime = 0 @@ -146,15 +153,15 @@ async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process logger.warning("Error reporting stats: %s", e) -def start_phone_stats_home(hs: "HomeServer"): +def start_phone_stats_home(hs: "HomeServer") -> None: """ Start the background tasks which report phone home stats. """ clock = hs.get_clock() - stats = {} + stats: JsonDict = {} - def performance_stats_init(): + def performance_stats_init() -> None: _stats_process.clear() _stats_process.append( (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF)) @@ -170,10 +177,10 @@ def start_phone_stats_home(hs: "HomeServer"): hs.get_datastore().reap_monthly_active_users() @wrap_as_background_process("generate_monthly_active_users") - async def generate_monthly_active_users(): + async def generate_monthly_active_users() -> None: current_mau_count = 0 current_mau_count_by_service = {} - reserved_users = () + reserved_users: Sized = () store = hs.get_datastore() if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: current_mau_count = await store.get_monthly_active_count() diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index d08f6bbd7f..f51b636417 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -231,13 +231,32 @@ class ApplicationServiceApi(SimpleHttpClient): json_body=body, args={"access_token": service.hs_token}, ) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "push_bulk to %s succeeded! events=%s", + uri, + [event.get("event_id") for event in events], + ) sent_transactions_counter.labels(service.id).inc() sent_events_counter.labels(service.id).inc(len(events)) return True except CodeMessageException as e: - logger.warning("push_bulk to %s received %s", uri, e.code) + logger.warning( + "push_bulk to %s received code=%s msg=%s", + uri, + e.code, + e.msg, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) except Exception as ex: - logger.warning("push_bulk to %s threw exception %s", uri, ex) + logger.warning( + "push_bulk to %s threw exception(%s) %s args=%s", + uri, + type(ex).__name__, + ex, + ex.args, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) failed_transactions_counter.labels(service.id).inc() return False diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 7c4428a138..1265738dc1 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -20,7 +20,18 @@ import os from collections import OrderedDict from hashlib import sha256 from textwrap import dedent -from typing import Any, Iterable, List, MutableMapping, Optional, Union +from typing import ( + Any, + Dict, + Iterable, + List, + MutableMapping, + Optional, + Tuple, + Type, + TypeVar, + Union, +) import attr import jinja2 @@ -78,7 +89,7 @@ CONFIG_FILE_HEADER = """\ """ -def path_exists(file_path): +def path_exists(file_path: str) -> bool: """Check if a file exists Unlike os.path.exists, this throws an exception if there is an error @@ -86,7 +97,7 @@ def path_exists(file_path): the parent dir). Returns: - bool: True if the file exists; False if not. + True if the file exists; False if not. """ try: os.stat(file_path) @@ -102,15 +113,15 @@ class Config: A configuration section, containing configuration keys and values. Attributes: - section (str): The section title of this config object, such as + section: The section title of this config object, such as "tls" or "logger". This is used to refer to it on the root logger (for example, `config.tls.some_option`). Must be defined in subclasses. """ - section = None + section: str - def __init__(self, root_config=None): + def __init__(self, root_config: "RootConfig" = None): self.root = root_config # Get the path to the default Synapse template directory @@ -119,7 +130,7 @@ class Config: ) @staticmethod - def parse_size(value): + def parse_size(value: Union[str, int]) -> int: if isinstance(value, int): return value sizes = {"K": 1024, "M": 1024 * 1024} @@ -162,15 +173,15 @@ class Config: return int(value) * size @staticmethod - def abspath(file_path): + def abspath(file_path: str) -> str: return os.path.abspath(file_path) if file_path else file_path @classmethod - def path_exists(cls, file_path): + def path_exists(cls, file_path: str) -> bool: return path_exists(file_path) @classmethod - def check_file(cls, file_path, config_name): + def check_file(cls, file_path: Optional[str], config_name: str) -> str: if file_path is None: raise ConfigError("Missing config for %s." % (config_name,)) try: @@ -183,7 +194,7 @@ class Config: return cls.abspath(file_path) @classmethod - def ensure_directory(cls, dir_path): + def ensure_directory(cls, dir_path: str) -> str: dir_path = cls.abspath(dir_path) os.makedirs(dir_path, exist_ok=True) if not os.path.isdir(dir_path): @@ -191,7 +202,7 @@ class Config: return dir_path @classmethod - def read_file(cls, file_path, config_name): + def read_file(cls, file_path: Any, config_name: str) -> str: """Deprecated: call read_file directly""" return read_file(file_path, (config_name,)) @@ -284,6 +295,9 @@ class Config: return [env.get_template(filename) for filename in filenames] +TRootConfig = TypeVar("TRootConfig", bound="RootConfig") + + class RootConfig: """ Holder of an application's configuration. @@ -308,7 +322,9 @@ class RootConfig: raise Exception("Failed making %s: %r" % (config_class.section, e)) setattr(self, config_class.section, conf) - def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]: + def invoke_all( + self, func_name: str, *args: Any, **kwargs: Any + ) -> MutableMapping[str, Any]: """ Invoke a function on all instantiated config objects this RootConfig is configured to use. @@ -317,6 +333,7 @@ class RootConfig: func_name: Name of function to invoke *args **kwargs + Returns: ordered dictionary of config section name and the result of the function from it. @@ -332,7 +349,7 @@ class RootConfig: return res @classmethod - def invoke_all_static(cls, func_name: str, *args, **kwargs): + def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: any) -> None: """ Invoke a static function on config objects this RootConfig is configured to use. @@ -341,6 +358,7 @@ class RootConfig: func_name: Name of function to invoke *args **kwargs + Returns: ordered dictionary of config section name and the result of the function from it. @@ -351,16 +369,16 @@ class RootConfig: def generate_config( self, - config_dir_path, - data_dir_path, - server_name, - generate_secrets=False, - report_stats=None, - open_private_ports=False, - listeners=None, - tls_certificate_path=None, - tls_private_key_path=None, - ): + config_dir_path: str, + data_dir_path: str, + server_name: str, + generate_secrets: bool = False, + report_stats: Optional[bool] = None, + open_private_ports: bool = False, + listeners: Optional[List[dict]] = None, + tls_certificate_path: Optional[str] = None, + tls_private_key_path: Optional[str] = None, + ) -> str: """ Build a default configuration file @@ -368,27 +386,27 @@ class RootConfig: (eg with --generate_config). Args: - config_dir_path (str): The path where the config files are kept. Used to + config_dir_path: The path where the config files are kept. Used to create filenames for things like the log config and the signing key. - data_dir_path (str): The path where the data files are kept. Used to create + data_dir_path: The path where the data files are kept. Used to create filenames for things like the database and media store. - server_name (str): The server name. Used to initialise the server_name + server_name: The server name. Used to initialise the server_name config param, but also used in the names of some of the config files. - generate_secrets (bool): True if we should generate new secrets for things + generate_secrets: True if we should generate new secrets for things like the macaroon_secret_key. If False, these parameters will be left unset. - report_stats (bool|None): Initial setting for the report_stats setting. + report_stats: Initial setting for the report_stats setting. If None, report_stats will be left unset. - open_private_ports (bool): True to leave private ports (such as the non-TLS + open_private_ports: True to leave private ports (such as the non-TLS HTTP listener) open to the internet. - listeners (list(dict)|None): A list of descriptions of the listeners - synapse should start with each of which specifies a port (str), a list of + listeners: A list of descriptions of the listeners synapse should + start with each of which specifies a port (int), a list of resources (list(str)), tls (bool) and type (str). For example: [{ "port": 8448, @@ -403,16 +421,12 @@ class RootConfig: "type": "http", }], + tls_certificate_path: The path to the tls certificate. - database (str|None): The database type to configure, either `psycog2` - or `sqlite3`. - - tls_certificate_path (str|None): The path to the tls certificate. - - tls_private_key_path (str|None): The path to the tls private key. + tls_private_key_path: The path to the tls private key. Returns: - str: the yaml config file + The yaml config file """ return CONFIG_FILE_HEADER + "\n\n".join( @@ -432,12 +446,15 @@ class RootConfig: ) @classmethod - def load_config(cls, description, argv): + def load_config( + cls: Type[TRootConfig], description: str, argv: List[str] + ) -> TRootConfig: """Parse the commandline and config files Doesn't support config-file-generation: used by the worker apps. - Returns: Config object. + Returns: + Config object. """ config_parser = argparse.ArgumentParser(description=description) cls.add_arguments_to_parser(config_parser) @@ -446,7 +463,7 @@ class RootConfig: return obj @classmethod - def add_arguments_to_parser(cls, config_parser): + def add_arguments_to_parser(cls, config_parser: argparse.ArgumentParser) -> None: """Adds all the config flags to an ArgumentParser. Doesn't support config-file-generation: used by the worker apps. @@ -454,7 +471,7 @@ class RootConfig: Used for workers where we want to add extra flags/subcommands. Args: - config_parser (ArgumentParser): App description + config_parser: App description """ config_parser.add_argument( @@ -477,7 +494,9 @@ class RootConfig: cls.invoke_all_static("add_arguments", config_parser) @classmethod - def load_config_with_parser(cls, parser, argv): + def load_config_with_parser( + cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv: List[str] + ) -> Tuple[TRootConfig, argparse.Namespace]: """Parse the commandline and config files with the given parser Doesn't support config-file-generation: used by the worker apps. @@ -485,13 +504,12 @@ class RootConfig: Used for workers where we want to add extra flags/subcommands. Args: - parser (ArgumentParser) - argv (list[str]) + parser + argv Returns: - tuple[HomeServerConfig, argparse.Namespace]: Returns the parsed - config object and the parsed argparse.Namespace object from - `parser.parse_args(..)` + Returns the parsed config object and the parsed argparse.Namespace + object from parser.parse_args(..)` """ obj = cls() @@ -520,12 +538,15 @@ class RootConfig: return obj, config_args @classmethod - def load_or_generate_config(cls, description, argv): + def load_or_generate_config( + cls: Type[TRootConfig], description: str, argv: List[str] + ) -> Optional[TRootConfig]: """Parse the commandline and config files Supports generation of config files, so is used for the main homeserver app. - Returns: Config object, or None if --generate-config or --generate-keys was set + Returns: + Config object, or None if --generate-config or --generate-keys was set """ parser = argparse.ArgumentParser(description=description) parser.add_argument( @@ -680,16 +701,21 @@ class RootConfig: return obj - def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None): + def parse_config_dict( + self, + config_dict: Dict[str, Any], + config_dir_path: Optional[str] = None, + data_dir_path: Optional[str] = None, + ) -> None: """Read the information from the config dict into this Config object. Args: - config_dict (dict): Configuration data, as read from the yaml + config_dict: Configuration data, as read from the yaml - config_dir_path (str): The path where the config files are kept. Used to + config_dir_path: The path where the config files are kept. Used to create filenames for things like the log config and the signing key. - data_dir_path (str): The path where the data files are kept. Used to create + data_dir_path: The path where the data files are kept. Used to create filenames for things like the database and media store. """ self.invoke_all( @@ -699,17 +725,20 @@ class RootConfig: data_dir_path=data_dir_path, ) - def generate_missing_files(self, config_dict, config_dir_path): + def generate_missing_files( + self, config_dict: Dict[str, Any], config_dir_path: str + ) -> None: self.invoke_all("generate_files", config_dict, config_dir_path) -def read_config_files(config_files): +def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]: """Read the config files into a dict Args: - config_files (iterable[str]): A list of the config files to read + config_files: A list of the config files to read - Returns: dict + Returns: + The configuration dictionary. """ specified_config = {} for config_file in config_files: @@ -733,17 +762,17 @@ def read_config_files(config_files): return specified_config -def find_config_files(search_paths): +def find_config_files(search_paths: List[str]) -> List[str]: """Finds config files using a list of search paths. If a path is a file then that file path is added to the list. If a search path is a directory then all the "*.yaml" files in that directory are added to the list in sorted order. Args: - search_paths(list(str)): A list of paths to search. + search_paths: A list of paths to search. Returns: - list(str): A list of file paths. + A list of file paths. """ config_files = [] @@ -777,7 +806,7 @@ def find_config_files(search_paths): return config_files -@attr.s +@attr.s(auto_attribs=True) class ShardedWorkerHandlingConfig: """Algorithm for choosing which instance is responsible for handling some sharded work. @@ -787,7 +816,7 @@ class ShardedWorkerHandlingConfig: below). """ - instances = attr.ib(type=List[str]) + instances: List[str] def should_handle(self, instance_name: str, key: str) -> bool: """Whether this instance is responsible for handling the given key.""" diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index c1d9069798..1eb5f5a68c 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -1,4 +1,18 @@ -from typing import Any, Iterable, List, Optional +import argparse +from typing import ( + Any, + Dict, + Iterable, + List, + MutableMapping, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +import jinja2 from synapse.config import ( account_validity, @@ -19,6 +33,7 @@ from synapse.config import ( logger, metrics, modules, + oembed, oidc, password_auth_providers, push, @@ -27,6 +42,7 @@ from synapse.config import ( registration, repository, retention, + room, room_directory, saml2, server, @@ -51,7 +67,9 @@ MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str MISSING_REPORT_STATS_SPIEL: str MISSING_SERVER_NAME: str -def path_exists(file_path: str): ... +def path_exists(file_path: str) -> bool: ... + +TRootConfig = TypeVar("TRootConfig", bound="RootConfig") class RootConfig: server: server.ServerConfig @@ -61,6 +79,7 @@ class RootConfig: logging: logger.LoggingConfig ratelimiting: ratelimiting.RatelimitConfig media: repository.ContentRepositoryConfig + oembed: oembed.OembedConfig captcha: captcha.CaptchaConfig voip: voip.VoipConfig registration: registration.RegistrationConfig @@ -80,6 +99,7 @@ class RootConfig: authproviders: password_auth_providers.PasswordAuthProviderConfig push: push.PushConfig spamchecker: spam_checker.SpamCheckerConfig + room: room.RoomConfig groups: groups.GroupsConfig userdirectory: user_directory.UserDirectoryConfig consent: consent.ConsentConfig @@ -87,72 +107,85 @@ class RootConfig: servernotices: server_notices.ServerNoticesConfig roomdirectory: room_directory.RoomDirectoryConfig thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig - tracer: tracer.TracerConfig + tracing: tracer.TracerConfig redis: redis.RedisConfig modules: modules.ModulesConfig caches: cache.CacheConfig federation: federation.FederationConfig retention: retention.RetentionConfig - config_classes: List = ... + config_classes: List[Type["Config"]] = ... def __init__(self) -> None: ... - def invoke_all(self, func_name: str, *args: Any, **kwargs: Any): ... + def invoke_all( + self, func_name: str, *args: Any, **kwargs: Any + ) -> MutableMapping[str, Any]: ... @classmethod def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ... - def __getattr__(self, item: str): ... def parse_config_dict( self, - config_dict: Any, - config_dir_path: Optional[Any] = ..., - data_dir_path: Optional[Any] = ..., + config_dict: Dict[str, Any], + config_dir_path: Optional[str] = ..., + data_dir_path: Optional[str] = ..., ) -> None: ... - read_config: Any = ... def generate_config( self, config_dir_path: str, data_dir_path: str, server_name: str, generate_secrets: bool = ..., - report_stats: Optional[str] = ..., + report_stats: Optional[bool] = ..., open_private_ports: bool = ..., listeners: Optional[Any] = ..., - database_conf: Optional[Any] = ..., tls_certificate_path: Optional[str] = ..., tls_private_key_path: Optional[str] = ..., - ): ... + ) -> str: ... @classmethod - def load_or_generate_config(cls, description: Any, argv: Any): ... + def load_or_generate_config( + cls: Type[TRootConfig], description: str, argv: List[str] + ) -> Optional[TRootConfig]: ... @classmethod - def load_config(cls, description: Any, argv: Any): ... + def load_config( + cls: Type[TRootConfig], description: str, argv: List[str] + ) -> TRootConfig: ... @classmethod - def add_arguments_to_parser(cls, config_parser: Any) -> None: ... + def add_arguments_to_parser( + cls, config_parser: argparse.ArgumentParser + ) -> None: ... @classmethod - def load_config_with_parser(cls, parser: Any, argv: Any): ... + def load_config_with_parser( + cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv: List[str] + ) -> Tuple[TRootConfig, argparse.Namespace]: ... def generate_missing_files( self, config_dict: dict, config_dir_path: str ) -> None: ... class Config: root: RootConfig + default_template_dir: str def __init__(self, root_config: Optional[RootConfig] = ...) -> None: ... - def __getattr__(self, item: str, from_root: bool = ...): ... @staticmethod - def parse_size(value: Any): ... + def parse_size(value: Union[str, int]) -> int: ... @staticmethod - def parse_duration(value: Any): ... + def parse_duration(value: Union[str, int]) -> int: ... @staticmethod - def abspath(file_path: Optional[str]): ... + def abspath(file_path: Optional[str]) -> str: ... @classmethod - def path_exists(cls, file_path: str): ... + def path_exists(cls, file_path: str) -> bool: ... @classmethod - def check_file(cls, file_path: str, config_name: str): ... + def check_file(cls, file_path: str, config_name: str) -> str: ... @classmethod - def ensure_directory(cls, dir_path: str): ... + def ensure_directory(cls, dir_path: str) -> str: ... @classmethod - def read_file(cls, file_path: str, config_name: str): ... + def read_file(cls, file_path: str, config_name: str) -> str: ... + def read_template(self, filenames: str) -> jinja2.Template: ... + def read_templates( + self, + filenames: List[str], + custom_template_directories: Optional[Iterable[str]] = None, + ) -> List[jinja2.Template]: ... -def read_config_files(config_files: List[str]): ... -def find_config_files(search_paths: List[str]): ... +def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]: ... +def find_config_files(search_paths: List[str]) -> List[str]: ... class ShardedWorkerHandlingConfig: instances: List[str] diff --git a/synapse/config/cache.py b/synapse/config/cache.py index d119427ad8..f054455534 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -15,7 +15,7 @@ import os import re import threading -from typing import Callable, Dict +from typing import Callable, Dict, Optional from synapse.python_dependencies import DependencyException, check_requirements @@ -217,7 +217,7 @@ class CacheConfig(Config): expiry_time = cache_config.get("expiry_time") if expiry_time: - self.expiry_time_msec = self.parse_duration(expiry_time) + self.expiry_time_msec: Optional[int] = self.parse_duration(expiry_time) else: self.expiry_time_msec = None diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index afd65fecd3..510b647c63 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -137,33 +137,14 @@ class EmailConfig(Config): if self.root.registration.account_threepid_delegate_email else ThreepidBehaviour.LOCAL ) - # Prior to Synapse v1.4.0, there was another option that defined whether Synapse would - # use an identity server to password reset tokens on its behalf. We now warn the user - # if they have this set and tell them to use the updated option, while using a default - # identity server in the process. - self.using_identity_server_from_trusted_list = False - if ( - not self.root.registration.account_threepid_delegate_email - and config.get("trust_identity_server_for_password_resets", False) is True - ): - # Use the first entry in self.trusted_third_party_id_servers instead - if self.trusted_third_party_id_servers: - # XXX: It's a little confusing that account_threepid_delegate_email is modified - # both in RegistrationConfig and here. We should factor this bit out - first_trusted_identity_server = self.trusted_third_party_id_servers[0] - - # trusted_third_party_id_servers does not contain a scheme whereas - # account_threepid_delegate_email is expected to. Presume https - self.root.registration.account_threepid_delegate_email = ( - "https://" + first_trusted_identity_server - ) - self.using_identity_server_from_trusted_list = True - else: - raise ConfigError( - "Attempted to use an identity server from" - '"trusted_third_party_id_servers" but it is empty.' - ) + if config.get("trust_identity_server_for_password_resets"): + raise ConfigError( + 'The config option "trust_identity_server_for_password_resets" ' + 'has been replaced by "account_threepid_delegate". ' + "Please consult the sample config at docs/sample_config.yaml for " + "details and update your config file." + ) self.local_threepid_handling_disabled_due_to_email_config = False if ( diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py index 9d295f5856..24c3ef01fc 100644 --- a/synapse/config/jwt.py +++ b/synapse/config/jwt.py @@ -31,6 +31,8 @@ class JWTConfig(Config): self.jwt_secret = jwt_config["secret"] self.jwt_algorithm = jwt_config["algorithm"] + self.jwt_subject_claim = jwt_config.get("subject_claim", "sub") + # The issuer and audiences are optional, if provided, it is asserted # that the claims exist on the JWT. self.jwt_issuer = jwt_config.get("issuer") @@ -46,6 +48,7 @@ class JWTConfig(Config): self.jwt_enabled = False self.jwt_secret = None self.jwt_algorithm = None + self.jwt_subject_claim = None self.jwt_issuer = None self.jwt_audiences = None @@ -88,6 +91,12 @@ class JWTConfig(Config): # #algorithm: "provided-by-your-issuer" + # Name of the claim containing a unique identifier for the user. + # + # Optional, defaults to `sub`. + # + #subject_claim: "sub" + # The issuer to validate the "iss" claim against. # # Optional, if provided the "iss" claim will be required and diff --git a/synapse/config/key.py b/synapse/config/key.py index 015dbb8a67..035ee2416b 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -16,6 +16,7 @@ import hashlib import logging import os +from typing import Any, Dict import attr import jsonschema @@ -312,7 +313,7 @@ class KeyConfig(Config): ) return keys - def generate_files(self, config, config_dir_path): + def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None: if "signing_key" in config: return diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 5252e61a99..63aab0babe 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -18,7 +18,7 @@ import os import sys import threading from string import Template -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict import yaml from zope.interface import implementer @@ -185,7 +185,7 @@ class LoggingConfig(Config): help=argparse.SUPPRESS, ) - def generate_files(self, config, config_dir_path): + def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None: log_config = config.get("log_config") if log_config and not os.path.exists(log_config): log_file = self.abspath("homeserver.log") diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 5379e80715..61e569d412 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -39,9 +39,7 @@ class RegistrationConfig(Config): self.registration_shared_secret = config.get("registration_shared_secret") self.bcrypt_rounds = config.get("bcrypt_rounds", 12) - self.trusted_third_party_id_servers = config.get( - "trusted_third_party_id_servers", ["matrix.org", "vector.im"] - ) + account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") @@ -114,25 +112,32 @@ class RegistrationConfig(Config): session_lifetime = self.parse_duration(session_lifetime) self.session_lifetime = session_lifetime - # The `access_token_lifetime` applies for tokens that can be renewed + # The `refreshable_access_token_lifetime` applies for tokens that can be renewed # using a refresh token, as per MSC2918. If it is `None`, the refresh # token mechanism is disabled. # # Since it is incompatible with the `session_lifetime` mechanism, it is set to # `None` by default if a `session_lifetime` is set. - access_token_lifetime = config.get( - "access_token_lifetime", "5m" if session_lifetime is None else None + refreshable_access_token_lifetime = config.get( + "refreshable_access_token_lifetime", + "5m" if session_lifetime is None else None, ) - if access_token_lifetime is not None: - access_token_lifetime = self.parse_duration(access_token_lifetime) - self.access_token_lifetime = access_token_lifetime + if refreshable_access_token_lifetime is not None: + refreshable_access_token_lifetime = self.parse_duration( + refreshable_access_token_lifetime + ) + self.refreshable_access_token_lifetime = refreshable_access_token_lifetime - if session_lifetime is not None and access_token_lifetime is not None: + if ( + session_lifetime is not None + and refreshable_access_token_lifetime is not None + ): raise ConfigError( "The refresh token mechanism is incompatible with the " "`session_lifetime` option. Consider disabling the " "`session_lifetime` option or disabling the refresh token " - "mechanism by removing the `access_token_lifetime` option." + "mechanism by removing the `refreshable_access_token_lifetime` " + "option." ) # The fallback template used for authenticating using a registration token diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 56981cac79..57316c59b6 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -1,4 +1,5 @@ # Copyright 2018 New Vector Ltd +# Copyright 2021 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + +from synapse.types import JsonDict from synapse.util import glob_to_regex from ._base import Config, ConfigError @@ -20,7 +24,7 @@ from ._base import Config, ConfigError class RoomDirectoryConfig(Config): section = "roomdirectory" - def read_config(self, config, **kwargs): + def read_config(self, config, **kwargs) -> None: self.enable_room_list_search = config.get("enable_room_list_search", True) alias_creation_rules = config.get("alias_creation_rules") @@ -47,7 +51,7 @@ class RoomDirectoryConfig(Config): _RoomDirectoryRule("room_list_publication_rules", {"action": "allow"}) ] - def generate_config_section(self, config_dir_path, server_name, **kwargs): + def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str: return """ # Uncomment to disable searching the public room list. When disabled # blocks searching local and remote room lists for local and remote @@ -113,16 +117,16 @@ class RoomDirectoryConfig(Config): # action: allow """ - def is_alias_creation_allowed(self, user_id, room_id, alias): + def is_alias_creation_allowed(self, user_id: str, room_id: str, alias: str) -> bool: """Checks if the given user is allowed to create the given alias Args: - user_id (str) - room_id (str) - alias (str) + user_id: The user to check. + room_id: The room ID for the alias. + alias: The alias being created. Returns: - boolean: True if user is allowed to create the alias + True if user is allowed to create the alias """ for rule in self._alias_creation_rules: if rule.matches(user_id, room_id, [alias]): @@ -130,16 +134,18 @@ class RoomDirectoryConfig(Config): return False - def is_publishing_room_allowed(self, user_id, room_id, aliases): + def is_publishing_room_allowed( + self, user_id: str, room_id: str, aliases: List[str] + ) -> bool: """Checks if the given user is allowed to publish the room Args: - user_id (str) - room_id (str) - aliases (list[str]): any local aliases associated with the room + user_id: The user ID publishing the room. + room_id: The room being published. + aliases: any local aliases associated with the room Returns: - boolean: True if user can publish room + True if user can publish room """ for rule in self._room_list_publication_rules: if rule.matches(user_id, room_id, aliases): @@ -153,11 +159,11 @@ class _RoomDirectoryRule: creating an alias or publishing a room. """ - def __init__(self, option_name, rule): + def __init__(self, option_name: str, rule: JsonDict): """ Args: - option_name (str): Name of the config option this rule belongs to - rule (dict): The rule as specified in the config + option_name: Name of the config option this rule belongs to + rule: The rule as specified in the config """ action = rule["action"] @@ -181,18 +187,18 @@ class _RoomDirectoryRule: except Exception as e: raise ConfigError("Failed to parse glob into regex") from e - def matches(self, user_id, room_id, aliases): + def matches(self, user_id: str, room_id: str, aliases: List[str]) -> bool: """Tests if this rule matches the given user_id, room_id and aliases. Args: - user_id (str) - room_id (str) - aliases (list[str]): The associated aliases to the room. Will be a - single element for testing alias creation, and can be empty for - testing room publishing. + user_id: The user ID to check. + room_id: The room ID to check. + aliases: The associated aliases to the room. Will be a single element + for testing alias creation, and can be empty for testing room + publishing. Returns: - boolean + True if the rule matches. """ # Note: The regexes are anchored at both ends diff --git a/synapse/config/server.py b/synapse/config/server.py index 7bc0030a9e..8445e9dd05 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -421,7 +421,7 @@ class ServerConfig(Config): # before redacting them. redaction_retention_period = config.get("redaction_retention_period", "7d") if redaction_retention_period is not None: - self.redaction_retention_period = self.parse_duration( + self.redaction_retention_period: Optional[int] = self.parse_duration( redaction_retention_period ) else: @@ -430,7 +430,7 @@ class ServerConfig(Config): # How long to keep entries in the `users_ips` table. user_ips_max_age = config.get("user_ips_max_age", "28d") if user_ips_max_age is not None: - self.user_ips_max_age = self.parse_duration(user_ips_max_age) + self.user_ips_max_age: Optional[int] = self.parse_duration(user_ips_max_age) else: self.user_ips_max_age = None diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 6227434bac..4ca111618f 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -14,7 +14,6 @@ import logging import os -from datetime import datetime from typing import List, Optional, Pattern from OpenSSL import SSL, crypto @@ -133,55 +132,6 @@ class TlsConfig(Config): self.tls_certificate: Optional[crypto.X509] = None self.tls_private_key: Optional[crypto.PKey] = None - def is_disk_cert_valid(self, allow_self_signed=True): - """ - Is the certificate we have on disk valid, and if so, for how long? - - Args: - allow_self_signed (bool): Should we allow the certificate we - read to be self signed? - - Returns: - int: Days remaining of certificate validity. - None: No certificate exists. - """ - if not os.path.exists(self.tls_certificate_file): - return None - - try: - with open(self.tls_certificate_file, "rb") as f: - cert_pem = f.read() - except Exception as e: - raise ConfigError( - "Failed to read existing certificate file %s: %s" - % (self.tls_certificate_file, e) - ) - - try: - tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) - except Exception as e: - raise ConfigError( - "Failed to parse existing certificate file %s: %s" - % (self.tls_certificate_file, e) - ) - - if not allow_self_signed: - if tls_certificate.get_subject() == tls_certificate.get_issuer(): - raise ValueError( - "TLS Certificate is self signed, and this is not permitted" - ) - - # YYYYMMDDhhmmssZ -- in UTC - expiry_data = tls_certificate.get_notAfter() - if expiry_data is None: - raise ValueError( - "TLS Certificate has no expiry date, and this is not permitted" - ) - expires_on = datetime.strptime(expiry_data.decode("ascii"), "%Y%m%d%H%M%SZ") - now = datetime.utcnow() - days_remaining = (expires_on - now).days - return days_remaining - def read_certificate_from_disk(self): """ Read the certificates and private key from disk. @@ -263,8 +213,8 @@ class TlsConfig(Config): # #federation_certificate_verification_whitelist: # - lon.example.com - # - *.domain.com - # - *.onion + # - "*.domain.com" + # - "*.onion" # List of custom certificate authorities for federation traffic. # @@ -295,7 +245,7 @@ class TlsConfig(Config): cert_path = self.tls_certificate_file logger.info("Loading TLS certificate from %s", cert_path) cert_pem = self.read_file(cert_path, "tls_certificate_path") - cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) + cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem.encode()) return cert diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py index 2552f688d0..6d6678c7e4 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py @@ -53,8 +53,8 @@ class UserDirectoryConfig(Config): # indexes were (re)built was before Synapse 1.44, you'll have to # rebuild the indexes in order to search through all known users. # These indexes are built the first time Synapse starts; admins can - # manually trigger a rebuild following the instructions at - # https://matrix-org.github.io/synapse/latest/user_directory.html + # manually trigger a rebuild via API following the instructions at + # https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/background_updates.html#run # # Uncomment to return search results containing all known users, even if that # user does not share a room with the requester. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index f641ab7ef5..4cda439ad9 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -1,5 +1,4 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017, 2018 New Vector Ltd +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -120,16 +119,6 @@ class VerifyJsonRequest: key_ids=key_ids, ) - def to_fetch_key_request(self) -> "_FetchKeyRequest": - """Create a key fetch request for all keys needed to satisfy the - verification request. - """ - return _FetchKeyRequest( - server_name=self.server_name, - minimum_valid_until_ts=self.minimum_valid_until_ts, - key_ids=self.key_ids, - ) - class KeyLookupError(ValueError): pass @@ -179,8 +168,22 @@ class Keyring: clock=hs.get_clock(), process_batch_callback=self._inner_fetch_key_requests, ) - self.verify_key = get_verify_key(hs.signing_key) - self.hostname = hs.hostname + + self._hostname = hs.hostname + + # build a FetchKeyResult for each of our own keys, to shortcircuit the + # fetcher. + self._local_verify_keys: Dict[str, FetchKeyResult] = {} + for key_id, key in hs.config.key.old_signing_keys.items(): + self._local_verify_keys[key_id] = FetchKeyResult( + verify_key=key, valid_until_ts=key.expired_ts + ) + + vk = get_verify_key(hs.signing_key) + self._local_verify_keys[f"{vk.alg}:{vk.version}"] = FetchKeyResult( + verify_key=vk, + valid_until_ts=2 ** 63, # fake future timestamp + ) async def verify_json_for_server( self, @@ -267,22 +270,32 @@ class Keyring: Codes.UNAUTHORIZED, ) - # If we are the originating server don't fetch verify key for self over federation - if verify_request.server_name == self.hostname: - await self._process_json(self.verify_key, verify_request) - return + found_keys: Dict[str, FetchKeyResult] = {} - # Add the keys we need to verify to the queue for retrieval. We queue - # up requests for the same server so we don't end up with many in flight - # requests for the same keys. - key_request = verify_request.to_fetch_key_request() - found_keys_by_server = await self._server_queue.add_to_queue( - key_request, key=verify_request.server_name - ) + # If we are the originating server, short-circuit the key-fetch for any keys + # we already have + if verify_request.server_name == self._hostname: + for key_id in verify_request.key_ids: + if key_id in self._local_verify_keys: + found_keys[key_id] = self._local_verify_keys[key_id] + + key_ids_to_find = set(verify_request.key_ids) - found_keys.keys() + if key_ids_to_find: + # Add the keys we need to verify to the queue for retrieval. We queue + # up requests for the same server so we don't end up with many in flight + # requests for the same keys. + key_request = _FetchKeyRequest( + server_name=verify_request.server_name, + minimum_valid_until_ts=verify_request.minimum_valid_until_ts, + key_ids=list(key_ids_to_find), + ) + found_keys_by_server = await self._server_queue.add_to_queue( + key_request, key=verify_request.server_name + ) - # Since we batch up requests the returned set of keys may contain keys - # from other servers, so we pull out only the ones we care about.s - found_keys = found_keys_by_server.get(verify_request.server_name, {}) + # Since we batch up requests the returned set of keys may contain keys + # from other servers, so we pull out only the ones we care about. + found_keys.update(found_keys_by_server.get(verify_request.server_name, {})) # Verify each signature we got valid keys for, raising if we can't # verify any of them. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 4f409f31e1..eb39e0ae32 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -128,14 +128,12 @@ class EventBuilder: ) format_version = self.room_version.event_format + # The types of auth/prev events changes between event versions. + prev_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] + auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] if format_version == EventFormatVersions.V1: - # The types of auth/prev events changes between event versions. - auth_events: Union[ - List[str], List[Tuple[str, Dict[str, str]]] - ] = await self._store.add_event_hashes(auth_event_ids) - prev_events: Union[ - List[str], List[Tuple[str, Dict[str, str]]] - ] = await self._store.add_event_hashes(prev_event_ids) + auth_events = await self._store.add_event_hashes(auth_event_ids) + prev_events = await self._store.add_event_hashes(prev_event_ids) else: auth_events = auth_event_ids prev_events = prev_event_ids diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 6fa631aa1d..e5967c995e 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -392,15 +393,16 @@ class EventClientSerializer: self, event: Union[JsonDict, EventBase], time_now: int, - bundle_aggregations: bool = True, + bundle_relations: bool = True, **kwargs: Any, ) -> JsonDict: """Serializes a single event. Args: - event + event: The event being serialized. time_now: The current time in milliseconds - bundle_aggregations: Whether to bundle in related events + bundle_relations: Whether to include the bundled relations for this + event. **kwargs: Arguments to pass to `serialize_event` Returns: @@ -410,77 +412,93 @@ class EventClientSerializer: if not isinstance(event, EventBase): return event - event_id = event.event_id serialized_event = serialize_event(event, time_now, **kwargs) # If MSC1849 is enabled then we need to look if there are any relations # we need to bundle in with the event. # Do not bundle relations if the event has been redacted if not event.internal_metadata.is_redacted() and ( - self._msc1849_enabled and bundle_aggregations + self._msc1849_enabled and bundle_relations ): - annotations = await self.store.get_aggregation_groups_for_event(event_id) - references = await self.store.get_relations_for_event( - event_id, RelationTypes.REFERENCE, direction="f" - ) - - if annotations.chunk: - r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.ANNOTATION] = annotations.to_dict() - - if references.chunk: - r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.REFERENCE] = references.to_dict() - - edit = None - if event.type == EventTypes.Message: - edit = await self.store.get_applicable_edit(event_id) - - if edit: - # If there is an edit replace the content, preserving existing - # relations. - - # Ensure we take copies of the edit content, otherwise we risk modifying - # the original event. - edit_content = edit.content.copy() - - # Unfreeze the event content if necessary, so that we may modify it below - edit_content = unfreeze(edit_content) - serialized_event["content"] = edit_content.get("m.new_content", {}) - - # Check for existing relations - relations = event.content.get("m.relates_to") - if relations: - # Keep the relations, ensuring we use a dict copy of the original - serialized_event["content"]["m.relates_to"] = relations.copy() - else: - serialized_event["content"].pop("m.relates_to", None) - - r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.REPLACE] = { - "event_id": edit.event_id, - "origin_server_ts": edit.origin_server_ts, - "sender": edit.sender, - } - - # If this event is the start of a thread, include a summary of the replies. - if self._msc3440_enabled: - ( - thread_count, - latest_thread_event, - ) = await self.store.get_thread_summary(event_id) - if latest_thread_event: - r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.THREAD] = { - # Don't bundle aggregations as this could recurse forever. - "latest_event": await self.serialize_event( - latest_thread_event, time_now, bundle_aggregations=False - ), - "count": thread_count, - } + await self._injected_bundled_relations(event, time_now, serialized_event) return serialized_event + async def _injected_bundled_relations( + self, event: EventBase, time_now: int, serialized_event: JsonDict + ) -> None: + """Potentially injects bundled relations into the unsigned portion of the serialized event. + + Args: + event: The event being serialized. + time_now: The current time in milliseconds + serialized_event: The serialized event which may be modified. + + """ + event_id = event.event_id + + # The bundled relations to include. + relations = {} + + annotations = await self.store.get_aggregation_groups_for_event(event_id) + if annotations.chunk: + relations[RelationTypes.ANNOTATION] = annotations.to_dict() + + references = await self.store.get_relations_for_event( + event_id, RelationTypes.REFERENCE, direction="f" + ) + if references.chunk: + relations[RelationTypes.REFERENCE] = references.to_dict() + + edit = None + if event.type == EventTypes.Message: + edit = await self.store.get_applicable_edit(event_id) + + if edit: + # If there is an edit replace the content, preserving existing + # relations. + + # Ensure we take copies of the edit content, otherwise we risk modifying + # the original event. + edit_content = edit.content.copy() + + # Unfreeze the event content if necessary, so that we may modify it below + edit_content = unfreeze(edit_content) + serialized_event["content"] = edit_content.get("m.new_content", {}) + + # Check for existing relations + relates_to = event.content.get("m.relates_to") + if relates_to: + # Keep the relations, ensuring we use a dict copy of the original + serialized_event["content"]["m.relates_to"] = relates_to.copy() + else: + serialized_event["content"].pop("m.relates_to", None) + + relations[RelationTypes.REPLACE] = { + "event_id": edit.event_id, + "origin_server_ts": edit.origin_server_ts, + "sender": edit.sender, + } + + # If this event is the start of a thread, include a summary of the replies. + if self._msc3440_enabled: + ( + thread_count, + latest_thread_event, + ) = await self.store.get_thread_summary(event_id) + if latest_thread_event: + relations[RelationTypes.THREAD] = { + # Don't bundle relations as this could recurse forever. + "latest_event": await self.serialize_event( + latest_thread_event, time_now, bundle_relations=False + ), + "count": thread_count, + } + + # If any bundled relations were found, include them. + if relations: + serialized_event["unsigned"].setdefault("m.relations", {}).update(relations) + async def serialize_events( self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any ) -> List[JsonDict]: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 670186f548..3b85b135e0 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -277,6 +277,58 @@ class FederationClient(FederationBase): return pdus + async def get_pdu_from_destination_raw( + self, + destination: str, + event_id: str, + room_version: RoomVersion, + outlier: bool = False, + timeout: Optional[int] = None, + ) -> Optional[EventBase]: + """Requests the PDU with given origin and ID from the remote home + server. Does not have any caching or rate limiting! + + Args: + destination: Which homeserver to query + event_id: event to fetch + room_version: version of the room + outlier: Indicates whether the PDU is an `outlier`, i.e. if + it's from an arbitrary point in the context as opposed to part + of the current block of PDUs. Defaults to `False` + timeout: How long to try (in ms) each destination for before + moving to the next destination. None indicates no timeout. + + Returns: + The requested PDU, or None if we were unable to find it. + + Raises: + SynapseError, NotRetryingDestination, FederationDeniedError + """ + transaction_data = await self.transport_layer.get_event( + destination, event_id, timeout=timeout + ) + + logger.debug( + "retrieved event id %s from %s: %r", + event_id, + destination, + transaction_data, + ) + + pdu_list: List[EventBase] = [ + event_from_pdu_json(p, room_version, outlier=outlier) + for p in transaction_data["pdus"] + ] + + if pdu_list and pdu_list[0]: + pdu = pdu_list[0] + + # Check signatures are correct. + signed_pdu = await self._check_sigs_and_hash(room_version, pdu) + return signed_pdu + + return None + async def get_pdu( self, destinations: Iterable[str], @@ -321,30 +373,14 @@ class FederationClient(FederationBase): continue try: - transaction_data = await self.transport_layer.get_event( - destination, event_id, timeout=timeout - ) - - logger.debug( - "retrieved event id %s from %s: %r", - event_id, - destination, - transaction_data, + signed_pdu = await self.get_pdu_from_destination_raw( + destination=destination, + event_id=event_id, + room_version=room_version, + outlier=outlier, + timeout=timeout, ) - pdu_list: List[EventBase] = [ - event_from_pdu_json(p, room_version, outlier=outlier) - for p in transaction_data["pdus"] - ] - - if pdu_list and pdu_list[0]: - pdu = pdu_list[0] - - # Check signatures are correct. - signed_pdu = await self._check_sigs_and_hash(room_version, pdu) - - break - pdu_attempts[destination] = now except SynapseError as e: diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index 53f99031b1..a87896e538 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -40,6 +40,8 @@ from typing import TYPE_CHECKING, Optional, Tuple from signedjson.sign import sign_json +from twisted.internet.defer import Deferred + from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import JsonDict, get_domain_from_id @@ -166,7 +168,7 @@ class GroupAttestionRenewer: return {} - def _start_renew_attestations(self) -> None: + def _start_renew_attestations(self) -> "Deferred[None]": return run_as_background_process("renew_attestations", self._renew_attestations) async def _renew_attestations(self) -> None: diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index be3203ac80..85157a138b 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -234,7 +234,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): @abc.abstractmethod def write_invite( - self, room_id: str, event: EventBase, state: StateMap[dict] + self, room_id: str, event: EventBase, state: StateMap[EventBase] ) -> None: """Write an invite for the room, with associated invite state. @@ -248,7 +248,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): @abc.abstractmethod def write_knock( - self, room_id: str, event: EventBase, state: StateMap[dict] + self, room_id: str, event: EventBase, state: StateMap[EventBase] ) -> None: """Write a knock for the room, with associated knock state. diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index ddc9105ee9..9abdad262b 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -188,7 +188,7 @@ class ApplicationServicesHandler: self, stream_key: str, new_token: Union[int, RoomStreamToken], - users: Optional[Collection[Union[str, UserID]]] = None, + users: Collection[Union[str, UserID]], ) -> None: """ This is called by the notifier in the background when an ephemeral event is handled @@ -203,7 +203,9 @@ class ApplicationServicesHandler: value for `stream_key` will cause this function to return early. Ephemeral events will only be pushed to appservices that have opted into - them. + receiving them by setting `push_ephemeral` to true in their registration + file. Note that while MSC2409 is experimental, this option is called + `de.sorunome.msc2409.push_ephemeral`. Appservices will only receive ephemeral events that fall within their registered user and room namespaces. @@ -214,6 +216,7 @@ class ApplicationServicesHandler: if not self.notify_appservices: return + # Ignore any unsupported streams if stream_key not in ("typing_key", "receipt_key", "presence_key"): return @@ -230,18 +233,25 @@ class ApplicationServicesHandler: # Additional context: https://github.com/matrix-org/synapse/pull/11137 assert isinstance(new_token, int) + # Check whether there are any appservices which have registered to receive + # ephemeral events. + # + # Note that whether these events are actually relevant to these appservices + # is decided later on. services = [ service for service in self.store.get_app_services() if service.supports_ephemeral ] if not services: + # Bail out early if none of the target appservices have explicitly registered + # to receive these ephemeral events. return # We only start a new background process if necessary rather than # optimistically (to cut down on overhead). self._notify_interested_services_ephemeral( - services, stream_key, new_token, users or [] + services, stream_key, new_token, users ) @wrap_as_background_process("notify_interested_services_ephemeral") @@ -252,7 +262,7 @@ class ApplicationServicesHandler: new_token: int, users: Collection[Union[str, UserID]], ) -> None: - logger.debug("Checking interested services for %s" % (stream_key)) + logger.debug("Checking interested services for %s", stream_key) with Measure(self.clock, "notify_interested_services_ephemeral"): for service in services: if stream_key == "typing_key": @@ -345,6 +355,9 @@ class ApplicationServicesHandler: Args: service: The application service to check for which events it should receive. + new_token: A receipts event stream token. Purely used to double-check that the + from_token we pull from the database isn't greater than or equal to this + token. Prevents accidentally duplicating work. Returns: A list of JSON dictionaries containing data derived from the read receipts that @@ -382,6 +395,9 @@ class ApplicationServicesHandler: Args: service: The application service that ephemeral events are being sent to. users: The users that should receive the presence update. + new_token: A presence update stream token. Purely used to double-check that the + from_token we pull from the database isn't greater than or equal to this + token. Prevents accidentally duplicating work. Returns: A list of json dictionaries containing data derived from the presence events diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 60e59d11a0..4b66a9862f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -790,10 +790,10 @@ class AuthHandler: ( new_refresh_token, new_refresh_token_id, - ) = await self.get_refresh_token_for_user_id( + ) = await self.create_refresh_token_for_user_id( user_id=existing_token.user_id, device_id=existing_token.device_id ) - access_token = await self.get_access_token_for_user_id( + access_token = await self.create_access_token_for_user_id( user_id=existing_token.user_id, device_id=existing_token.device_id, valid_until_ms=valid_until_ms, @@ -832,7 +832,7 @@ class AuthHandler: return True - async def get_refresh_token_for_user_id( + async def create_refresh_token_for_user_id( self, user_id: str, device_id: str, @@ -855,7 +855,7 @@ class AuthHandler: ) return refresh_token, refresh_token_id - async def get_access_token_for_user_id( + async def create_access_token_for_user_id( self, user_id: str, device_id: Optional[str], @@ -1828,13 +1828,6 @@ def load_single_legacy_password_auth_provider( logger.error("Error while initializing %r: %s", module, e) raise - # The known hooks. If a module implements a method who's name appears in this set - # we'll want to register it - password_auth_provider_methods = { - "check_3pid_auth", - "on_logged_out", - } - # All methods that the module provides should be async, but this wasn't enforced # in the old module system, so we wrap them if needed def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: @@ -1919,11 +1912,14 @@ def load_single_legacy_password_auth_provider( return run - # populate hooks with the implemented methods, wrapped with async_wrapper - hooks = { - hook: async_wrapper(getattr(provider, hook, None)) - for hook in password_auth_provider_methods - } + # If the module has these methods implemented, then we pull them out + # and register them as hooks. + check_3pid_auth_hook: Optional[CHECK_3PID_AUTH_CALLBACK] = async_wrapper( + getattr(provider, "check_3pid_auth", None) + ) + on_logged_out_hook: Optional[ON_LOGGED_OUT_CALLBACK] = async_wrapper( + getattr(provider, "on_logged_out", None) + ) supported_login_types = {} # call get_supported_login_types and add that to the dict @@ -1950,7 +1946,11 @@ def load_single_legacy_password_auth_provider( # need to use a tuple here for ("password",) not a list since lists aren't hashable auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password - api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers) + api.register_password_auth_provider_callbacks( + check_3pid_auth=check_3pid_auth_hook, + on_logged_out=on_logged_out_hook, + auth_checkers=auth_checkers, + ) CHECK_3PID_AUTH_CALLBACK = Callable[ diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index b6a2a34ab7..b582266af9 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -89,6 +89,13 @@ class DeviceMessageHandler: ) async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: + """ + Handle receiving to-device messages from remote homeservers. + + Args: + origin: The remote homeserver. + content: The JSON dictionary containing the to-device messages. + """ local_messages = {} sender_user_id = content["sender"] if origin != get_domain_from_id(sender_user_id): @@ -135,12 +142,16 @@ class DeviceMessageHandler: message_type, sender_user_id, by_device ) - stream_id = await self.store.add_messages_from_remote_to_device_inbox( + # Add messages to the database. + # Retrieve the stream id of the last-processed to-device message. + last_stream_id = await self.store.add_messages_from_remote_to_device_inbox( origin, message_id, local_messages ) + # Notify listeners that there are new to-device messages to process, + # handing them the latest stream id. self.notifier.on_new_event( - "to_device_key", stream_id, users=local_messages.keys() + "to_device_key", last_stream_id, users=local_messages.keys() ) async def _check_for_unknown_devices( @@ -195,6 +206,14 @@ class DeviceMessageHandler: message_type: str, messages: Dict[str, Dict[str, JsonDict]], ) -> None: + """ + Handle a request from a user to send to-device message(s). + + Args: + requester: The user that is sending the to-device messages. + message_type: The type of to-device messages that are being sent. + messages: A dictionary containing recipients mapped to messages intended for them. + """ sender_user_id = requester.user.to_string() message_id = random_string(16) @@ -257,12 +276,16 @@ class DeviceMessageHandler: "org.matrix.opentracing_context": json_encoder.encode(context), } - stream_id = await self.store.add_messages_to_device_inbox( + # Add messages to the database. + # Retrieve the stream id of the last-processed to-device message. + last_stream_id = await self.store.add_messages_to_device_inbox( local_messages, remote_edu_contents ) + # Notify listeners that there are new to-device messages to process, + # handing them the latest stream id. self.notifier.on_new_event( - "to_device_key", stream_id, users=local_messages.keys() + "to_device_key", last_stream_id, users=local_messages.keys() ) if self.federation_sender: diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 8ca5f60b1c..7ee5c47fd9 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -204,6 +204,10 @@ class DirectoryHandler: ) room_id = await self._delete_association(room_alias) + if room_id is None: + # It's possible someone else deleted the association after the + # checks above, but before we did the deletion. + raise NotFoundError("Unknown room alias") try: await self._update_canonical_alias(requester, user_id, room_id, room_alias) @@ -225,7 +229,7 @@ class DirectoryHandler: ) await self._delete_association(room_alias) - async def _delete_association(self, room_alias: RoomAlias) -> str: + async def _delete_association(self, room_alias: RoomAlias) -> Optional[str]: if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 1f64534a8a..b4ff935546 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -124,7 +124,7 @@ class EventStreamHandler: as_client_event=as_client_event, # We don't bundle "live" events, as otherwise clients # will end up double counting annotations. - bundle_aggregations=False, + bundle_relations=False, ) chunk = { diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 1a1cd93b1a..9917613298 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -981,8 +981,6 @@ class FederationEventHandler: origin, event, context, - state=state, - backfilled=backfilled, ) except AuthError as e: # FIXME richvdh 2021/10/07 I don't think this is reachable. Let's log it @@ -1332,8 +1330,6 @@ class FederationEventHandler: origin: str, event: EventBase, context: EventContext, - state: Optional[Iterable[EventBase]] = None, - backfilled: bool = False, ) -> EventContext: """ Checks whether an event should be rejected (for failing auth checks). @@ -1344,12 +1340,6 @@ class FederationEventHandler: context: The event context. - state: - The state events used to check the event for soft-fail. If this is - not provided the current state events will be used. - - backfilled: True if the event was backfilled. - Returns: The updated context object. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 3dbe611f95..c83eaea359 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -464,15 +464,6 @@ class IdentityHandler: if next_link: params["next_link"] = next_link - if self.hs.config.email.using_identity_server_from_trusted_list: - # Warn that a deprecated config option is in use - logger.warning( - 'The config option "trust_identity_server_for_password_resets" ' - 'has been replaced by "account_threepid_delegate". ' - "Please consult the sample config at docs/sample_config.yaml for " - "details and update your config file." - ) - try: data = await self.http_client.post_json_get_json( id_server + "/_matrix/identity/api/v1/validate/email/requestToken", @@ -517,15 +508,6 @@ class IdentityHandler: if next_link: params["next_link"] = next_link - if self.hs.config.email.using_identity_server_from_trusted_list: - # Warn that a deprecated config option is in use - logger.warning( - 'The config option "trust_identity_server_for_password_resets" ' - 'has been replaced by "account_threepid_delegate". ' - "Please consult the sample config at docs/sample_config.yaml for " - "details and update your config file." - ) - try: data = await self.http_client.post_json_get_json( id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken", diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 6a6c468cb7..67e557aeaf 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -252,7 +252,7 @@ class MessageHandler: now, # We don't bother bundling aggregations in when asked for state # events, as clients won't use them. - bundle_aggregations=False, + bundle_relations=False, ) return events @@ -1001,13 +1001,52 @@ class EventCreationHandler: ) self.validator.validate_new(event, self.config) + await self._validate_event_relation(event) + logger.debug("Created event %s", event.event_id) + + return event, context + + async def _validate_event_relation(self, event: EventBase) -> None: + """ + Ensure the relation data on a new event is not bogus. + + Args: + event: The event being created. + + Raises: + SynapseError if the event is invalid. + """ + + relation = event.content.get("m.relates_to") + if not relation: + return + + relation_type = relation.get("rel_type") + if not relation_type: + return + + # Ensure the parent is real. + relates_to = relation.get("event_id") + if not relates_to: + return + + parent_event = await self.store.get_event(relates_to, allow_none=True) + if parent_event: + # And in the same room. + if parent_event.room_id != event.room_id: + raise SynapseError(400, "Relations must be in the same room") + + else: + # There must be some reason that the client knows the event exists, + # see if there are existing relations. If so, assume everything is fine. + if not await self.store.event_is_target_of_relation(relates_to): + # Otherwise, the client can't know about the parent event! + raise SynapseError(400, "Can't send relation to unknown event") # If this event is an annotation then we check that that the sender # can't annotate the same way twice (e.g. stops users from liking an # event multiple times). - relation = event.content.get("m.relates_to", {}) - if relation.get("rel_type") == RelationTypes.ANNOTATION: - relates_to = relation["event_id"] + if relation_type == RelationTypes.ANNOTATION: aggregation_key = relation["key"] already_exists = await self.store.has_user_annotated_event( @@ -1016,9 +1055,12 @@ class EventCreationHandler: if already_exists: raise SynapseError(400, "Can't send same reaction twice") - logger.debug("Created event %s", event.event_id) - - return event, context + # Don't attempt to start a thread if the parent event is a relation. + elif relation_type == RelationTypes.THREAD: + if await self.store.event_includes_relation(relates_to): + raise SynapseError( + 400, "Cannot start threads from an event with a relation" + ) @measure_func("handle_new_client_event") async def handle_new_client_event( diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index abfe7be0e3..cd64142735 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Optional, Set +from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set import attr @@ -22,7 +22,7 @@ from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.api.filtering import Filter -from synapse.logging.context import run_in_background +from synapse.handlers.room import ShutdownRoomResponse from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig @@ -56,11 +56,62 @@ class PurgeStatus: STATUS_FAILED: "failed", } + # Save the error message if an error occurs + error: str = "" + # Tracks whether this request has completed. One of STATUS_{ACTIVE,COMPLETE,FAILED}. status: int = STATUS_ACTIVE def asdict(self) -> JsonDict: - return {"status": PurgeStatus.STATUS_TEXT[self.status]} + ret = {"status": PurgeStatus.STATUS_TEXT[self.status]} + if self.error: + ret["error"] = self.error + return ret + + +@attr.s(slots=True, auto_attribs=True) +class DeleteStatus: + """Object tracking the status of a delete room request + + This class contains information on the progress of a delete room request, for + return by get_delete_status. + """ + + STATUS_PURGING = 0 + STATUS_COMPLETE = 1 + STATUS_FAILED = 2 + STATUS_SHUTTING_DOWN = 3 + + STATUS_TEXT = { + STATUS_PURGING: "purging", + STATUS_COMPLETE: "complete", + STATUS_FAILED: "failed", + STATUS_SHUTTING_DOWN: "shutting_down", + } + + # Tracks whether this request has completed. + # One of STATUS_{PURGING,COMPLETE,FAILED,SHUTTING_DOWN}. + status: int = STATUS_PURGING + + # Save the error message if an error occurs + error: str = "" + + # Saves the result of an action to give it back to REST API + shutdown_room: ShutdownRoomResponse = { + "kicked_users": [], + "failed_to_kick_users": [], + "local_aliases": [], + "new_room_id": None, + } + + def asdict(self) -> JsonDict: + ret = { + "status": DeleteStatus.STATUS_TEXT[self.status], + "shutdown_room": self.shutdown_room, + } + if self.error: + ret["error"] = self.error + return ret class PaginationHandler: @@ -70,6 +121,9 @@ class PaginationHandler: paginating during a purge. """ + # when to remove a completed deletion/purge from the results map + CLEAR_PURGE_AFTER_MS = 1000 * 3600 * 24 # 24 hours + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() @@ -78,11 +132,18 @@ class PaginationHandler: self.state_store = self.storage.state self.clock = hs.get_clock() self._server_name = hs.hostname + self._room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_lock = ReadWriteLock() + # IDs of rooms in which there currently an active purge *or delete* operation. self._purges_in_progress_by_room: Set[str] = set() # map from purge id to PurgeStatus self._purges_by_id: Dict[str, PurgeStatus] = {} + # map from purge id to DeleteStatus + self._delete_by_id: Dict[str, DeleteStatus] = {} + # map from room id to delete ids + # Dict[`room_id`, List[`delete_id`]] + self._delete_by_room: Dict[str, List[str]] = {} self._event_serializer = hs.get_event_client_serializer() self._retention_default_max_lifetime = ( @@ -265,8 +326,13 @@ class PaginationHandler: logger.info("[purge] starting purge_id %s", purge_id) self._purges_by_id[purge_id] = PurgeStatus() - run_in_background( - self._purge_history, purge_id, room_id, token, delete_local_events + run_as_background_process( + "purge_history", + self._purge_history, + purge_id, + room_id, + token, + delete_local_events, ) return purge_id @@ -276,7 +342,7 @@ class PaginationHandler: """Carry out a history purge on a room. Args: - purge_id: The id for this purge + purge_id: The ID for this purge. room_id: The room to purge from token: topological token to delete events before delete_local_events: True to delete local events as well as remote ones @@ -295,6 +361,7 @@ class PaginationHandler: "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) # type: ignore ) self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED + self._purges_by_id[purge_id].error = f.getErrorMessage() finally: self._purges_in_progress_by_room.discard(room_id) @@ -302,7 +369,9 @@ class PaginationHandler: def clear_purge() -> None: del self._purges_by_id[purge_id] - self.hs.get_reactor().callLater(24 * 3600, clear_purge) + self.hs.get_reactor().callLater( + PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000, clear_purge + ) def get_purge_status(self, purge_id: str) -> Optional[PurgeStatus]: """Get the current status of an active purge @@ -312,8 +381,25 @@ class PaginationHandler: """ return self._purges_by_id.get(purge_id) + def get_delete_status(self, delete_id: str) -> Optional[DeleteStatus]: + """Get the current status of an active deleting + + Args: + delete_id: delete_id returned by start_shutdown_and_purge_room + """ + return self._delete_by_id.get(delete_id) + + def get_delete_ids_by_room(self, room_id: str) -> Optional[Collection[str]]: + """Get all active delete ids by room + + Args: + room_id: room_id that is deleted + """ + return self._delete_by_room.get(room_id) + async def purge_room(self, room_id: str, force: bool = False) -> None: """Purge the given room from the database. + This function is part the delete room v1 API. Args: room_id: room to be purged @@ -424,7 +510,7 @@ class PaginationHandler: if events: if event_filter: - events = event_filter.filter(events) + events = await event_filter.filter(events) events = await filter_events_for_client( self.storage, user_id, events, is_peeking=(member_event_id is None) @@ -472,3 +558,192 @@ class PaginationHandler: ) return chunk + + async def _shutdown_and_purge_room( + self, + delete_id: str, + room_id: str, + requester_user_id: str, + new_room_user_id: Optional[str] = None, + new_room_name: Optional[str] = None, + message: Optional[str] = None, + block: bool = False, + purge: bool = True, + force_purge: bool = False, + ) -> None: + """ + Shuts down and purges a room. + + See `RoomShutdownHandler.shutdown_room` for details of creation of the new room + + Args: + delete_id: The ID for this delete. + room_id: The ID of the room to shut down. + requester_user_id: + User who requested the action. Will be recorded as putting the room on the + blocking list. + new_room_user_id: + If set, a new room will be created with this user ID + as the creator and admin, and all users in the old room will be + moved into that room. If not set, no new room will be created + and the users will just be removed from the old room. + new_room_name: + A string representing the name of the room that new users will + be invited to. Defaults to `Content Violation Notification` + message: + A string containing the first message that will be sent as + `new_room_user_id` in the new room. Ideally this will clearly + convey why the original room was shut down. + Defaults to `Sharing illegal content on this server is not + permitted and rooms in violation will be blocked.` + block: + If set to `true`, this room will be added to a blocking list, + preventing future attempts to join the room. Defaults to `false`. + purge: + If set to `true`, purge the given room from the database. + force_purge: + If set to `true`, the room will be purged from database + also if it fails to remove some users from room. + + Saves a `RoomShutdownHandler.ShutdownRoomResponse` in `DeleteStatus`: + """ + + self._purges_in_progress_by_room.add(room_id) + try: + with await self.pagination_lock.write(room_id): + self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN + self._delete_by_id[ + delete_id + ].shutdown_room = await self._room_shutdown_handler.shutdown_room( + room_id=room_id, + requester_user_id=requester_user_id, + new_room_user_id=new_room_user_id, + new_room_name=new_room_name, + message=message, + block=block, + ) + self._delete_by_id[delete_id].status = DeleteStatus.STATUS_PURGING + + if purge: + logger.info("starting purge room_id %s", room_id) + + # first check that we have no users in this room + if not force_purge: + joined = await self.store.is_host_joined( + room_id, self._server_name + ) + if joined: + raise SynapseError( + 400, "Users are still joined to this room" + ) + + await self.storage.purge_events.purge_room(room_id) + + logger.info("complete") + self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE + except Exception: + f = Failure() + logger.error( + "failed", + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + ) + self._delete_by_id[delete_id].status = DeleteStatus.STATUS_FAILED + self._delete_by_id[delete_id].error = f.getErrorMessage() + finally: + self._purges_in_progress_by_room.discard(room_id) + + # remove the delete from the list 24 hours after it completes + def clear_delete() -> None: + del self._delete_by_id[delete_id] + self._delete_by_room[room_id].remove(delete_id) + if not self._delete_by_room[room_id]: + del self._delete_by_room[room_id] + + self.hs.get_reactor().callLater( + PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000, clear_delete + ) + + def start_shutdown_and_purge_room( + self, + room_id: str, + requester_user_id: str, + new_room_user_id: Optional[str] = None, + new_room_name: Optional[str] = None, + message: Optional[str] = None, + block: bool = False, + purge: bool = True, + force_purge: bool = False, + ) -> str: + """Start off shut down and purge on a room. + + Args: + room_id: The ID of the room to shut down. + requester_user_id: + User who requested the action and put the room on the + blocking list. + new_room_user_id: + If set, a new room will be created with this user ID + as the creator and admin, and all users in the old room will be + moved into that room. If not set, no new room will be created + and the users will just be removed from the old room. + new_room_name: + A string representing the name of the room that new users will + be invited to. Defaults to `Content Violation Notification` + message: + A string containing the first message that will be sent as + `new_room_user_id` in the new room. Ideally this will clearly + convey why the original room was shut down. + Defaults to `Sharing illegal content on this server is not + permitted and rooms in violation will be blocked.` + block: + If set to `true`, this room will be added to a blocking list, + preventing future attempts to join the room. Defaults to `false`. + purge: + If set to `true`, purge the given room from the database. + force_purge: + If set to `true`, the room will be purged from database + also if it fails to remove some users from room. + + Returns: + unique ID for this delete transaction. + """ + if room_id in self._purges_in_progress_by_room: + raise SynapseError( + 400, "History purge already in progress for %s" % (room_id,) + ) + + # This check is double to `RoomShutdownHandler.shutdown_room` + # But here the requester get a direct response / error with HTTP request + # and do not have to check the purge status + if new_room_user_id is not None: + if not self.hs.is_mine_id(new_room_user_id): + raise SynapseError( + 400, "User must be our own: %s" % (new_room_user_id,) + ) + + delete_id = random_string(16) + + # we log the delete_id here so that it can be tied back to the + # request id in the log lines. + logger.info( + "starting shutdown room_id %s with delete_id %s", + room_id, + delete_id, + ) + + self._delete_by_id[delete_id] = DeleteStatus() + self._delete_by_room.setdefault(room_id, []).append(delete_id) + run_as_background_process( + "shutdown_and_purge_room", + self._shutdown_and_purge_room, + delete_id, + room_id, + requester_user_id, + new_room_user_id, + new_room_name, + message, + block, + purge, + force_purge, + ) + return delete_id diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a0e6a01775..448a36108e 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -116,7 +116,9 @@ class RegistrationHandler: self.pusher_pool = hs.get_pusherpool() self.session_lifetime = hs.config.registration.session_lifetime - self.access_token_lifetime = hs.config.registration.access_token_lifetime + self.refreshable_access_token_lifetime = ( + hs.config.registration.refreshable_access_token_lifetime + ) init_counters_for_auth_provider("") @@ -813,13 +815,15 @@ class RegistrationHandler: ( refresh_token, refresh_token_id, - ) = await self._auth_handler.get_refresh_token_for_user_id( + ) = await self._auth_handler.create_refresh_token_for_user_id( user_id, device_id=registered_device_id, ) - valid_until_ms = self.clock.time_msec() + self.access_token_lifetime + valid_until_ms = ( + self.clock.time_msec() + self.refreshable_access_token_lifetime + ) - access_token = await self._auth_handler.get_access_token_for_user_id( + access_token = await self._auth_handler.create_access_token_for_user_id( user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 969eb3b9b0..88053f9869 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Contains functions for performing events on rooms.""" - +"""Contains functions for performing actions on rooms.""" import itertools import logging import math @@ -31,6 +30,8 @@ from typing import ( Tuple, ) +from typing_extensions import TypedDict + from synapse.api.constants import ( EventContentFields, EventTypes, @@ -774,8 +775,11 @@ class RoomCreationHandler: raise SynapseError(403, "Room visibility value not allowed.") if is_public: + room_aliases = [] + if room_alias: + room_aliases.append(room_alias.to_string()) if not self.config.roomdirectory.is_publishing_room_allowed( - user_id, room_id, room_alias + user_id, room_id, room_aliases ): # Let's just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages @@ -1158,8 +1162,10 @@ class RoomContextHandler: ) if event_filter: - results["events_before"] = event_filter.filter(results["events_before"]) - results["events_after"] = event_filter.filter(results["events_after"]) + results["events_before"] = await event_filter.filter( + results["events_before"] + ) + results["events_after"] = await event_filter.filter(results["events_after"]) results["events_before"] = await filter_evts(results["events_before"]) results["events_after"] = await filter_evts(results["events_after"]) @@ -1195,7 +1201,7 @@ class RoomContextHandler: state_events = list(state[last_event_id].values()) if event_filter: - state_events = event_filter.filter(state_events) + state_events = await event_filter.filter(state_events) results["state"] = await filter_evts(state_events) @@ -1275,8 +1281,25 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): return self.store.get_room_events_max_id(room_id) -class RoomShutdownHandler: +class ShutdownRoomResponse(TypedDict): + """ + Attributes: + kicked_users: An array of users (`user_id`) that were kicked. + failed_to_kick_users: + An array of users (`user_id`) that that were not kicked. + local_aliases: + An array of strings representing the local aliases that were + migrated from the old room to the new. + new_room_id: A string representing the room ID of the new room. + """ + kicked_users: List[str] + failed_to_kick_users: List[str] + local_aliases: List[str] + new_room_id: Optional[str] + + +class RoomShutdownHandler: DEFAULT_MESSAGE = ( "Sharing illegal content on this server is not permitted and rooms in" " violation will be blocked." @@ -1289,7 +1312,6 @@ class RoomShutdownHandler: self._room_creation_handler = hs.get_room_creation_handler() self._replication = hs.get_replication_data_handler() self.event_creation_handler = hs.get_event_creation_handler() - self.state = hs.get_state_handler() self.store = hs.get_datastore() async def shutdown_room( @@ -1300,7 +1322,7 @@ class RoomShutdownHandler: new_room_name: Optional[str] = None, message: Optional[str] = None, block: bool = False, - ) -> dict: + ) -> ShutdownRoomResponse: """ Shuts down a room. Moves all local users and room aliases automatically to a new room if `new_room_user_id` is set. Otherwise local users only @@ -1334,8 +1356,13 @@ class RoomShutdownHandler: Defaults to `Sharing illegal content on this server is not permitted and rooms in violation will be blocked.` block: - If set to `true`, this room will be added to a blocking list, - preventing future attempts to join the room. Defaults to `false`. + If set to `True`, users will be prevented from joining the old + room. This option can also be used to pre-emptively block a room, + even if it's unknown to this homeserver. In this case, the room + will be blocked, and no further action will be taken. If `False`, + attempting to delete an unknown room is invalid. + + Defaults to `False`. Returns: a dict containing the following keys: kicked_users: An array of users (`user_id`) that were kicked. @@ -1344,7 +1371,9 @@ class RoomShutdownHandler: local_aliases: An array of strings representing the local aliases that were migrated from the old room to the new. - new_room_id: A string representing the room ID of the new room. + new_room_id: + A string representing the room ID of the new room, or None if + no such room was created. """ if not new_room_name: @@ -1355,14 +1384,28 @@ class RoomShutdownHandler: if not RoomID.is_valid(room_id): raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) - if not await self.store.get_room(room_id): - raise NotFoundError("Unknown room id %s" % (room_id,)) - - # This will work even if the room is already blocked, but that is - # desirable in case the first attempt at blocking the room failed below. + # Action the block first (even if the room doesn't exist yet) if block: + # This will work even if the room is already blocked, but that is + # desirable in case the first attempt at blocking the room failed below. await self.store.block_room(room_id, requester_user_id) + if not await self.store.get_room(room_id): + if block: + # We allow you to block an unknown room. + return { + "kicked_users": [], + "failed_to_kick_users": [], + "local_aliases": [], + "new_room_id": None, + } + else: + # But if you don't want to preventatively block another room, + # this function can't do anything useful. + raise NotFoundError( + "Cannot shut down room: unknown room id %s" % (room_id,) + ) + if new_room_user_id is not None: if not self.hs.is_mine_id(new_room_user_id): raise SynapseError( diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 0723286383..f880aa93d2 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -221,6 +221,7 @@ class RoomBatchHandler: action=membership, content=event_dict["content"], outlier=True, + historical=True, prev_event_ids=[prev_event_id_for_state_chain], # Make sure to use a copy of this list because we modify it # later in the loop here. Otherwise it will be the same @@ -240,6 +241,7 @@ class RoomBatchHandler: ), event_dict, outlier=True, + historical=True, prev_event_ids=[prev_event_id_for_state_chain], # Make sure to use a copy of this list because we modify it # later in the loop here. Otherwise it will be the same diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index fa3e9acc74..cac76d0221 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -269,6 +269,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): content: Optional[dict] = None, require_consent: bool = True, outlier: bool = False, + historical: bool = False, ) -> Tuple[str, int]: """ Internal membership update function to get an existing event or create @@ -294,6 +295,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. + historical: Indicates whether the message is being inserted + back in time around some existing events. This is used to skip + a few checks and mark the event as backfilled. Returns: Tuple of event ID and stream ordering position @@ -338,6 +342,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): auth_event_ids=auth_event_ids, require_consent=require_consent, outlier=outlier, + historical=historical, ) prev_state_ids = await context.get_prev_state_ids() @@ -434,6 +439,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): new_room: bool = False, require_consent: bool = True, outlier: bool = False, + historical: bool = False, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, ) -> Tuple[str, int]: @@ -455,6 +461,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. + historical: Indicates whether the message is being inserted + back in time around some existing events. This is used to skip + a few checks and mark the event as backfilled. prev_event_ids: The event IDs to use as the prev events auth_event_ids: The event ids to use as the auth_events for the new event. @@ -507,6 +516,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): new_room=new_room, require_consent=require_consent, outlier=outlier, + historical=historical, prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, ) @@ -527,6 +537,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): new_room: bool = False, require_consent: bool = True, outlier: bool = False, + historical: bool = False, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, ) -> Tuple[str, int]: @@ -550,6 +561,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. + historical: Indicates whether the message is being inserted + back in time around some existing events. This is used to skip + a few checks and mark the event as backfilled. prev_event_ids: The event IDs to use as the prev events auth_event_ids: The event ids to use as the auth_events for the new event. @@ -677,6 +691,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): content=content, require_consent=require_consent, outlier=outlier, + historical=historical, ) latest_event_ids = await self.store.get_prev_events_for_room(room_id) diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index fb26ee7ad7..8181cc0b52 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -97,7 +97,7 @@ class RoomSummaryHandler: # If a user tries to fetch the same page multiple times in quick succession, # only process the first attempt and return its result to subsequent requests. self._pagination_response_cache: ResponseCache[ - Tuple[str, bool, Optional[int], Optional[int], Optional[str]] + Tuple[str, str, bool, Optional[int], Optional[int], Optional[str]] ] = ResponseCache( hs.get_clock(), "get_room_hierarchy", @@ -282,7 +282,14 @@ class RoomSummaryHandler: # This is due to the pagination process mutating internal state, attempting # to process multiple requests for the same page will result in errors. return await self._pagination_response_cache.wrap( - (requested_room_id, suggested_only, max_depth, limit, from_token), + ( + requester, + requested_room_id, + suggested_only, + max_depth, + limit, + from_token, + ), self._get_room_hierarchy, requester, requested_room_id, diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 6e4dff8056..ab7eaab2fb 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -180,7 +180,7 @@ class SearchHandler: % (set(group_keys) - {"room_id", "sender"},), ) - search_filter = Filter(filter_dict) + search_filter = Filter(self.hs, filter_dict) # TODO: Search through left rooms too rooms = await self.store.get_rooms_for_local_user_where_membership_is( @@ -242,7 +242,7 @@ class SearchHandler: rank_map.update({r["event"].event_id: r["rank"] for r in results}) - filtered_events = search_filter.filter([r["event"] for r in results]) + filtered_events = await search_filter.filter([r["event"] for r in results]) events = await filter_events_for_client( self.storage, user.to_string(), filtered_events @@ -292,7 +292,9 @@ class SearchHandler: rank_map.update({r["event"].event_id: r["rank"] for r in results}) - filtered_events = search_filter.filter([r["event"] for r in results]) + filtered_events = await search_filter.filter( + [r["event"] for r in results] + ) events = await filter_events_for_client( self.storage, user.to_string(), filtered_events diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 2c7c6d63a9..891435c14d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -510,7 +510,7 @@ class SyncHandler: log_kv({"limited": limited}) if potential_recents: - recents = sync_config.filter_collection.filter_room_timeline( + recents = await sync_config.filter_collection.filter_room_timeline( potential_recents ) log_kv({"recents_after_sync_filtering": len(recents)}) @@ -575,8 +575,8 @@ class SyncHandler: log_kv({"loaded_recents": len(events)}) - loaded_recents = sync_config.filter_collection.filter_room_timeline( - events + loaded_recents = ( + await sync_config.filter_collection.filter_room_timeline(events) ) log_kv({"loaded_recents_after_sync_filtering": len(loaded_recents)}) @@ -1015,7 +1015,7 @@ class SyncHandler: return { (e.type, e.state_key): e - for e in sync_config.filter_collection.filter_room_state( + for e in await sync_config.filter_collection.filter_room_state( list(state.values()) ) if e.type != EventTypes.Aliases # until MSC2261 or alternative solution @@ -1383,7 +1383,7 @@ class SyncHandler: sync_config.user ) - account_data_for_user = sync_config.filter_collection.filter_account_data( + account_data_for_user = await sync_config.filter_collection.filter_account_data( [ {"type": account_data_type, "content": content} for account_data_type, content in account_data.items() @@ -1448,7 +1448,7 @@ class SyncHandler: # Deduplicate the presence entries so that there's at most one per user presence = list({p.user_id: p for p in presence}.values()) - presence = sync_config.filter_collection.filter_presence(presence) + presence = await sync_config.filter_collection.filter_presence(presence) sync_result_builder.presence = presence @@ -2021,12 +2021,14 @@ class SyncHandler: ) account_data_events = ( - sync_config.filter_collection.filter_room_account_data( + await sync_config.filter_collection.filter_room_account_data( account_data_events ) ) - ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral) + ephemeral = await sync_config.filter_collection.filter_room_ephemeral( + ephemeral + ) if not ( always_include diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 22c6174821..1676ebd057 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -90,7 +90,7 @@ class FollowerTypingHandler: self.wheel_timer = WheelTimer(bucket_size=5000) @wrap_as_background_process("typing._handle_timeouts") - def _handle_timeouts(self) -> None: + async def _handle_timeouts(self) -> None: logger.debug("Checking for typing timeouts") now = self.clock.time_msec() diff --git a/synapse/http/server.py b/synapse/http/server.py index 1af0d9a31d..91badb0b0a 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -98,7 +98,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: "Failed handle request via %r: %r", request.request_metrics.name, request, - exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type] ) # Only respond with an error response if we haven't already started writing, @@ -150,7 +150,7 @@ def return_html_error( logger.error( "Failed handle request %r", request, - exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type] ) else: code = HTTPStatus.INTERNAL_SERVER_ERROR @@ -159,7 +159,7 @@ def return_html_error( logger.error( "Failed handle request %r", request, - exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type] ) if isinstance(error_template, str): diff --git a/synapse/logging/handlers.py b/synapse/logging/handlers.py index af5fc407a8..478b527494 100644 --- a/synapse/logging/handlers.py +++ b/synapse/logging/handlers.py @@ -3,7 +3,7 @@ import time from logging import Handler, LogRecord from logging.handlers import MemoryHandler from threading import Thread -from typing import Optional +from typing import Optional, cast from twisted.internet.interfaces import IReactorCore @@ -56,7 +56,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler): if reactor is None: from twisted.internet import reactor as global_reactor - reactor_to_use = global_reactor # type: ignore[assignment] + reactor_to_use = cast(IReactorCore, global_reactor) else: reactor_to_use = reactor diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 91ee5c8193..ceef57ad88 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -20,10 +20,25 @@ import os import platform import threading import time -from typing import Callable, Dict, Iterable, Mapping, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) import attr -from prometheus_client import Counter, Gauge, Histogram +from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Metric from prometheus_client.core import ( REGISTRY, CounterMetricFamily, @@ -32,6 +47,7 @@ from prometheus_client.core import ( ) from twisted.internet import reactor +from twisted.internet.base import ReactorBase from twisted.python.threadpool import ThreadPool import synapse @@ -54,7 +70,7 @@ HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") class RegistryProxy: @staticmethod - def collect(): + def collect() -> Iterable[Metric]: for metric in REGISTRY.collect(): if not metric.name.startswith("__"): yield metric @@ -74,7 +90,7 @@ class LaterGauge: ] ) - def collect(self): + def collect(self) -> Iterable[Metric]: g = GaugeMetricFamily(self.name, self.desc, labels=self.labels) @@ -93,10 +109,10 @@ class LaterGauge: yield g - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._register() - def _register(self): + def _register(self) -> None: if self.name in all_gauges.keys(): logger.warning("%s already registered, reregistering" % (self.name,)) REGISTRY.unregister(all_gauges.pop(self.name)) @@ -105,7 +121,12 @@ class LaterGauge: all_gauges[self.name] = self -class InFlightGauge: +# `MetricsEntry` only makes sense when it is a `Protocol`, +# but `Protocol` can't be used as a `TypeVar` bound. +MetricsEntry = TypeVar("MetricsEntry") + + +class InFlightGauge(Generic[MetricsEntry]): """Tracks number of things (e.g. requests, Measure blocks, etc) in flight at any given time. @@ -115,14 +136,19 @@ class InFlightGauge: callbacks. Args: - name (str) - desc (str) - labels (list[str]) - sub_metrics (list[str]): A list of sub metrics that the callbacks - will update. + name + desc + labels + sub_metrics: A list of sub metrics that the callbacks will update. """ - def __init__(self, name, desc, labels, sub_metrics): + def __init__( + self, + name: str, + desc: str, + labels: Sequence[str], + sub_metrics: Sequence[str], + ): self.name = name self.desc = desc self.labels = labels @@ -130,19 +156,25 @@ class InFlightGauge: # Create a class which have the sub_metrics values as attributes, which # default to 0 on initialization. Used to pass to registered callbacks. - self._metrics_class = attr.make_class( + self._metrics_class: Type[MetricsEntry] = attr.make_class( "_MetricsEntry", attrs={x: attr.ib(0) for x in sub_metrics}, slots=True ) # Counts number of in flight blocks for a given set of label values - self._registrations: Dict = {} + self._registrations: Dict[ + Tuple[str, ...], Set[Callable[[MetricsEntry], None]] + ] = {} # Protects access to _registrations self._lock = threading.Lock() self._register_with_collector() - def register(self, key, callback): + def register( + self, + key: Tuple[str, ...], + callback: Callable[[MetricsEntry], None], + ) -> None: """Registers that we've entered a new block with labels `key`. `callback` gets called each time the metrics are collected. The same @@ -158,13 +190,17 @@ class InFlightGauge: with self._lock: self._registrations.setdefault(key, set()).add(callback) - def unregister(self, key, callback): + def unregister( + self, + key: Tuple[str, ...], + callback: Callable[[MetricsEntry], None], + ) -> None: """Registers that we've exited a block with labels `key`.""" with self._lock: self._registrations.setdefault(key, set()).discard(callback) - def collect(self): + def collect(self) -> Iterable[Metric]: """Called by prometheus client when it reads metrics. Note: may be called by a separate thread. @@ -200,7 +236,7 @@ class InFlightGauge: gauge.add_metric(key, getattr(metrics, name)) yield gauge - def _register_with_collector(self): + def _register_with_collector(self) -> None: if self.name in all_gauges.keys(): logger.warning("%s already registered, reregistering" % (self.name,)) REGISTRY.unregister(all_gauges.pop(self.name)) @@ -230,7 +266,7 @@ class GaugeBucketCollector: name: str, documentation: str, buckets: Iterable[float], - registry=REGISTRY, + registry: CollectorRegistry = REGISTRY, ): """ Args: @@ -257,12 +293,12 @@ class GaugeBucketCollector: registry.register(self) - def collect(self): + def collect(self) -> Iterable[Metric]: # Don't report metrics unless we've already collected some data if self._metric is not None: yield self._metric - def update_data(self, values: Iterable[float]): + def update_data(self, values: Iterable[float]) -> None: """Update the data to be reported by the metric The existing data is cleared, and each measurement in the input is assigned @@ -304,7 +340,7 @@ class GaugeBucketCollector: class CPUMetrics: - def __init__(self): + def __init__(self) -> None: ticks_per_sec = 100 try: # Try and get the system config @@ -314,7 +350,7 @@ class CPUMetrics: self.ticks_per_sec = ticks_per_sec - def collect(self): + def collect(self) -> Iterable[Metric]: if not HAVE_PROC_SELF_STAT: return @@ -364,7 +400,7 @@ gc_time = Histogram( class GCCounts: - def collect(self): + def collect(self) -> Iterable[Metric]: cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"]) for n, m in enumerate(gc.get_count()): cm.add_metric([str(n)], m) @@ -382,7 +418,7 @@ if not running_on_pypy: class PyPyGCStats: - def collect(self): + def collect(self) -> Iterable[Metric]: # @stats is a pretty-printer object with __str__() returning a nice table, # plus some fields that contain data from that table. @@ -565,7 +601,7 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None: class ReactorLastSeenMetric: - def collect(self): + def collect(self) -> Iterable[Metric]: cm = GaugeMetricFamily( "python_twisted_reactor_last_seen", "Seconds since the Twisted reactor was last seen", @@ -584,9 +620,12 @@ MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0) _last_gc = [0.0, 0.0, 0.0] -def runUntilCurrentTimer(reactor, func): +F = TypeVar("F", bound=Callable[..., Any]) + + +def runUntilCurrentTimer(reactor: ReactorBase, func: F) -> F: @functools.wraps(func) - def f(*args, **kwargs): + def f(*args: Any, **kwargs: Any) -> Any: now = reactor.seconds() num_pending = 0 @@ -649,7 +688,7 @@ def runUntilCurrentTimer(reactor, func): return ret - return f + return cast(F, f) try: @@ -677,5 +716,5 @@ __all__ = [ "start_http_server", "LaterGauge", "InFlightGauge", - "BucketCollector", + "GaugeBucketCollector", ] diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index bb9bcb5592..353d0a63b6 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -25,27 +25,25 @@ import math import threading from http.server import BaseHTTPRequestHandler, HTTPServer from socketserver import ThreadingMixIn -from typing import Dict, List +from typing import Any, Dict, List, Type, Union from urllib.parse import parse_qs, urlparse -from prometheus_client import REGISTRY +from prometheus_client import REGISTRY, CollectorRegistry +from prometheus_client.core import Sample from twisted.web.resource import Resource +from twisted.web.server import Request from synapse.util import caches CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" -INF = float("inf") -MINUS_INF = float("-inf") - - -def floatToGoString(d): +def floatToGoString(d: Union[int, float]) -> str: d = float(d) - if d == INF: + if d == math.inf: return "+Inf" - elif d == MINUS_INF: + elif d == -math.inf: return "-Inf" elif math.isnan(d): return "NaN" @@ -60,7 +58,7 @@ def floatToGoString(d): return s -def sample_line(line, name): +def sample_line(line: Sample, name: str) -> str: if line.labels: labelstr = "{{{0}}}".format( ",".join( @@ -82,7 +80,7 @@ def sample_line(line, name): return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp) -def generate_latest(registry, emit_help=False): +def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes: # Trigger the cache metrics to be rescraped, which updates the common # metrics but do not produce metrics themselves @@ -187,7 +185,7 @@ class MetricsHandler(BaseHTTPRequestHandler): registry = REGISTRY - def do_GET(self): + def do_GET(self) -> None: registry = self.registry params = parse_qs(urlparse(self.path).query) @@ -207,11 +205,11 @@ class MetricsHandler(BaseHTTPRequestHandler): self.end_headers() self.wfile.write(output) - def log_message(self, format, *args): + def log_message(self, format: str, *args: Any) -> None: """Log nothing.""" @classmethod - def factory(cls, registry): + def factory(cls, registry: CollectorRegistry) -> Type: """Returns a dynamic MetricsHandler class tied to the passed registry. """ @@ -236,7 +234,9 @@ class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer): daemon_threads = True -def start_http_server(port, addr="", registry=REGISTRY): +def start_http_server( + port: int, addr: str = "", registry: CollectorRegistry = REGISTRY +) -> None: """Starts an HTTP server for prometheus metrics as a daemon thread""" CustomMetricsHandler = MetricsHandler.factory(registry) httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler) @@ -252,10 +252,10 @@ class MetricsResource(Resource): isLeaf = True - def __init__(self, registry=REGISTRY): + def __init__(self, registry: CollectorRegistry = REGISTRY): self.registry = registry - def render_GET(self, request): + def render_GET(self, request: Request) -> bytes: request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii")) response = generate_latest(self.registry) request.setHeader(b"Content-Length", str(len(response))) diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 2ab599a334..53c508af91 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -15,19 +15,37 @@ import logging import threading from functools import wraps -from typing import TYPE_CHECKING, Dict, Optional, Set, Union +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Iterable, + Optional, + Set, + Type, + TypeVar, + Union, + cast, +) +from prometheus_client import Metric from prometheus_client.core import REGISTRY, Counter, Gauge from twisted.internet import defer -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import ( + ContextResourceUsage, + LoggingContext, + PreserveLoggingContext, +) from synapse.logging.opentracing import ( SynapseTags, noop_context_manager, start_active_span, ) -from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: import resource @@ -116,7 +134,7 @@ class _Collector: before they are returned. """ - def collect(self): + def collect(self) -> Iterable[Metric]: global _background_processes_active_since_last_scrape # We swap out the _background_processes set with an empty one so that @@ -144,12 +162,12 @@ REGISTRY.register(_Collector()) class _BackgroundProcess: - def __init__(self, desc, ctx): + def __init__(self, desc: str, ctx: LoggingContext): self.desc = desc self._context = ctx - self._reported_stats = None + self._reported_stats: Optional[ContextResourceUsage] = None - def update_metrics(self): + def update_metrics(self) -> None: """Updates the metrics with values from this process.""" new_stats = self._context.get_resource_usage() if self._reported_stats is None: @@ -169,7 +187,16 @@ class _BackgroundProcess: ) -def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwargs): +R = TypeVar("R") + + +def run_as_background_process( + desc: str, + func: Callable[..., Awaitable[Optional[R]]], + *args: Any, + bg_start_span: bool = True, + **kwargs: Any, +) -> "defer.Deferred[Optional[R]]": """Run the given function in its own logcontext, with resource metrics This should be used to wrap processes which are fired off to run in the @@ -189,11 +216,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar args: positional args for func kwargs: keyword args for func - Returns: Deferred which returns the result of func, but note that it does not - follow the synapse logcontext rules. + Returns: + Deferred which returns the result of func, or `None` if func raises. + Note that the returned Deferred does not follow the synapse logcontext + rules. """ - async def run(): + async def run() -> Optional[R]: with _bg_metrics_lock: count = _background_process_counts.get(desc, 0) _background_process_counts[desc] = count + 1 @@ -210,12 +239,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar else: ctx = noop_context_manager() with ctx: - return await maybe_awaitable(func(*args, **kwargs)) + return await func(*args, **kwargs) except Exception: logger.exception( "Background process '%s' threw an exception", desc, ) + return None finally: _background_process_in_flight_count.labels(desc).dec() @@ -225,19 +255,24 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar return defer.ensureDeferred(run()) -def wrap_as_background_process(desc): +F = TypeVar("F", bound=Callable[..., Awaitable[Optional[Any]]]) + + +def wrap_as_background_process(desc: str) -> Callable[[F], F]: """Decorator that wraps a function that gets called as a background process. - Equivalent of calling the function with `run_as_background_process` + Equivalent to calling the function with `run_as_background_process` """ - def wrap_as_background_process_inner(func): + def wrap_as_background_process_inner(func: F) -> F: @wraps(func) - def wrap_as_background_process_inner_2(*args, **kwargs): + def wrap_as_background_process_inner_2( + *args: Any, **kwargs: Any + ) -> "defer.Deferred[Optional[R]]": return run_as_background_process(desc, func, *args, **kwargs) - return wrap_as_background_process_inner_2 + return cast(F, wrap_as_background_process_inner_2) return wrap_as_background_process_inner @@ -265,7 +300,7 @@ class BackgroundProcessLoggingContext(LoggingContext): super().__init__("%s-%s" % (name, instance_id)) self._proc = _BackgroundProcess(name, self) - def start(self, rusage: "Optional[resource.struct_rusage]"): + def start(self, rusage: "Optional[resource.struct_rusage]") -> None: """Log context has started running (again).""" super().start(rusage) @@ -276,7 +311,12 @@ class BackgroundProcessLoggingContext(LoggingContext): with _bg_metrics_lock: _background_processes_active_since_last_scrape.add(self._proc) - def __exit__(self, type, value, traceback) -> None: + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: """Log context has finished.""" super().__exit__(type, value, traceback) diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py index 29ab6c0229..98ed9c0829 100644 --- a/synapse/metrics/jemalloc.py +++ b/synapse/metrics/jemalloc.py @@ -16,14 +16,16 @@ import ctypes import logging import os import re -from typing import Optional +from typing import Iterable, Optional + +from prometheus_client import Metric from synapse.metrics import REGISTRY, GaugeMetricFamily logger = logging.getLogger(__name__) -def _setup_jemalloc_stats(): +def _setup_jemalloc_stats() -> None: """Checks to see if jemalloc is loaded, and hooks up a collector to record statistics exposed by jemalloc. """ @@ -135,7 +137,7 @@ def _setup_jemalloc_stats(): class JemallocCollector: """Metrics for internal jemalloc stats.""" - def collect(self): + def collect(self) -> Iterable[Metric]: _jemalloc_refresh_stats() g = GaugeMetricFamily( @@ -185,7 +187,7 @@ def _setup_jemalloc_stats(): logger.debug("Added jemalloc stats") -def setup_jemalloc_stats(): +def setup_jemalloc_stats() -> None: """Try to setup jemalloc stats, if jemalloc is loaded.""" try: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 6e7f5238fe..96d7a8f2a9 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -31,11 +31,48 @@ import attr import jinja2 from twisted.internet import defer -from twisted.web.resource import IResource +from twisted.web.resource import Resource from synapse.api.errors import SynapseError from synapse.events import EventBase -from synapse.events.presence_router import PresenceRouter +from synapse.events.presence_router import ( + GET_INTERESTED_USERS_CALLBACK, + GET_USERS_FOR_STATES_CALLBACK, + PresenceRouter, +) +from synapse.events.spamcheck import ( + CHECK_EVENT_FOR_SPAM_CALLBACK, + CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK, + CHECK_REGISTRATION_FOR_SPAM_CALLBACK, + CHECK_USERNAME_FOR_SPAM_CALLBACK, + USER_MAY_CREATE_ROOM_ALIAS_CALLBACK, + USER_MAY_CREATE_ROOM_CALLBACK, + USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK, + USER_MAY_INVITE_CALLBACK, + USER_MAY_JOIN_ROOM_CALLBACK, + USER_MAY_PUBLISH_ROOM_CALLBACK, + USER_MAY_SEND_3PID_INVITE_CALLBACK, +) +from synapse.events.third_party_rules import ( + CHECK_EVENT_ALLOWED_CALLBACK, + CHECK_THREEPID_CAN_BE_INVITED_CALLBACK, + CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK, + ON_CREATE_ROOM_CALLBACK, + ON_NEW_EVENT_CALLBACK, +) +from synapse.handlers.account_validity import ( + IS_USER_EXPIRED_CALLBACK, + ON_LEGACY_ADMIN_REQUEST, + ON_LEGACY_RENEW_CALLBACK, + ON_LEGACY_SEND_MAIL_CALLBACK, + ON_USER_REGISTRATION_CALLBACK, +) +from synapse.handlers.auth import ( + CHECK_3PID_AUTH_CALLBACK, + CHECK_AUTH_CALLBACK, + ON_LOGGED_OUT_CALLBACK, + AuthHandler, +) from synapse.http.client import SimpleHttpClient from synapse.http.server import ( DirectServeHtmlResource, @@ -114,7 +151,7 @@ class ModuleApi: can register new users etc if necessary. """ - def __init__(self, hs: "HomeServer", auth_handler): + def __init__(self, hs: "HomeServer", auth_handler: AuthHandler) -> None: self._hs = hs # TODO: Fix this type hint once the types for the data stores have been ironed @@ -156,47 +193,121 @@ class ModuleApi: ################################################################################# # The following methods should only be called during the module's initialisation. - @property - def register_spam_checker_callbacks(self): + def register_spam_checker_callbacks( + self, + check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None, + user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None, + user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None, + user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None, + user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None, + user_may_create_room_with_invites: Optional[ + USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK + ] = None, + user_may_create_room_alias: Optional[ + USER_MAY_CREATE_ROOM_ALIAS_CALLBACK + ] = None, + user_may_publish_room: Optional[USER_MAY_PUBLISH_ROOM_CALLBACK] = None, + check_username_for_spam: Optional[CHECK_USERNAME_FOR_SPAM_CALLBACK] = None, + check_registration_for_spam: Optional[ + CHECK_REGISTRATION_FOR_SPAM_CALLBACK + ] = None, + check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None, + ) -> None: """Registers callbacks for spam checking capabilities. Added in Synapse v1.37.0. """ - return self._spam_checker.register_callbacks + return self._spam_checker.register_callbacks( + check_event_for_spam=check_event_for_spam, + user_may_join_room=user_may_join_room, + user_may_invite=user_may_invite, + user_may_send_3pid_invite=user_may_send_3pid_invite, + user_may_create_room=user_may_create_room, + user_may_create_room_with_invites=user_may_create_room_with_invites, + user_may_create_room_alias=user_may_create_room_alias, + user_may_publish_room=user_may_publish_room, + check_username_for_spam=check_username_for_spam, + check_registration_for_spam=check_registration_for_spam, + check_media_file_for_spam=check_media_file_for_spam, + ) - @property - def register_account_validity_callbacks(self): + def register_account_validity_callbacks( + self, + is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, + on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, + on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, + on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, + on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, + ) -> None: """Registers callbacks for account validity capabilities. Added in Synapse v1.39.0. """ - return self._account_validity_handler.register_account_validity_callbacks + return self._account_validity_handler.register_account_validity_callbacks( + is_user_expired=is_user_expired, + on_user_registration=on_user_registration, + on_legacy_send_mail=on_legacy_send_mail, + on_legacy_renew=on_legacy_renew, + on_legacy_admin_request=on_legacy_admin_request, + ) - @property - def register_third_party_rules_callbacks(self): + def register_third_party_rules_callbacks( + self, + check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None, + on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None, + check_threepid_can_be_invited: Optional[ + CHECK_THREEPID_CAN_BE_INVITED_CALLBACK + ] = None, + check_visibility_can_be_modified: Optional[ + CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK + ] = None, + on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, + ) -> None: """Registers callbacks for third party event rules capabilities. Added in Synapse v1.39.0. """ - return self._third_party_event_rules.register_third_party_rules_callbacks + return self._third_party_event_rules.register_third_party_rules_callbacks( + check_event_allowed=check_event_allowed, + on_create_room=on_create_room, + check_threepid_can_be_invited=check_threepid_can_be_invited, + check_visibility_can_be_modified=check_visibility_can_be_modified, + on_new_event=on_new_event, + ) - @property - def register_presence_router_callbacks(self): + def register_presence_router_callbacks( + self, + get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None, + get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None, + ) -> None: """Registers callbacks for presence router capabilities. Added in Synapse v1.42.0. """ - return self._presence_router.register_presence_router_callbacks + return self._presence_router.register_presence_router_callbacks( + get_users_for_states=get_users_for_states, + get_interested_users=get_interested_users, + ) - @property - def register_password_auth_provider_callbacks(self): + def register_password_auth_provider_callbacks( + self, + check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None, + on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None, + auth_checkers: Optional[ + Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK] + ] = None, + ) -> None: """Registers callbacks for password auth provider capabilities. Added in Synapse v1.46.0. """ - return self._password_auth_provider.register_password_auth_provider_callbacks + return self._password_auth_provider.register_password_auth_provider_callbacks( + check_3pid_auth=check_3pid_auth, + on_logged_out=on_logged_out, + auth_checkers=auth_checkers, + ) - def register_web_resource(self, path: str, resource: IResource): + def register_web_resource(self, path: str, resource: Resource): """Registers a web resource to be served at the given path. This function should be called during initialisation of the module. @@ -216,7 +327,7 @@ class ModuleApi: # The following methods can be called by the module at any point in time. @property - def http_client(self): + def http_client(self) -> SimpleHttpClient: """Allows making outbound HTTP requests to remote resources. An instance of synapse.http.client.SimpleHttpClient @@ -226,7 +337,7 @@ class ModuleApi: return self._http_client @property - def public_room_list_manager(self): + def public_room_list_manager(self) -> "PublicRoomListManager": """Allows adding to, removing from and checking the status of rooms in the public room list. @@ -309,7 +420,7 @@ class ModuleApi: """ return await self._store.is_server_admin(UserID.from_string(user_id)) - def get_qualified_user_id(self, username): + def get_qualified_user_id(self, username: str) -> str: """Qualify a user id, if necessary Takes a user id provided by the user and adds the @ and :domain to @@ -318,7 +429,7 @@ class ModuleApi: Added in Synapse v0.25.0. Args: - username (str): provided user id + username: provided user id Returns: str: qualified @user:id @@ -357,13 +468,13 @@ class ModuleApi: """ return await self._store.user_get_threepids(user_id) - def check_user_exists(self, user_id): + def check_user_exists(self, user_id: str): """Check if user exists. Added in Synapse v0.25.0. Args: - user_id (str): Complete @user:id + user_id: Complete @user:id Returns: Deferred[str|None]: Canonical (case-corrected) user_id, or None @@ -903,7 +1014,7 @@ class ModuleApi: A list containing the loaded templates, with the orders matching the one of the filenames parameter. """ - return self._hs.config.read_templates( + return self._hs.config.server.read_templates( filenames, (td for td in (self.custom_template_dir, custom_template_directory) if td), ) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 55326877fd..a9d85f4f6c 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING from prometheus_client import Counter -from twisted.internet.protocol import Factory +from twisted.internet.protocol import ServerFactory from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import PositionCommand @@ -38,7 +38,7 @@ stream_updates_counter = Counter( logger = logging.getLogger(__name__) -class ReplicationStreamProtocolFactory(Factory): +class ReplicationStreamProtocolFactory(ServerFactory): """Factory for new replication connections.""" def __init__(self, hs: "HomeServer"): diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index e04af705eb..cebdeecb81 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable from synapse.http.server import HttpServer, JsonResource from synapse.rest import admin @@ -62,6 +62,8 @@ from synapse.rest.client import ( if TYPE_CHECKING: from synapse.server import HomeServer +RegisterServletsFunc = Callable[["HomeServer", HttpServer], None] + class ClientRestResource(JsonResource): """Matrix Client API REST resource. diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 81e98f81d6..ee4a5e481b 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -28,6 +28,7 @@ from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.rest.admin.background_updates import ( BackgroundUpdateEnabledRestServlet, BackgroundUpdateRestServlet, + BackgroundUpdateStartJobRestServlet, ) from synapse.rest.admin.devices import ( DeleteDevicesRestServlet, @@ -46,6 +47,9 @@ from synapse.rest.admin.registration_tokens import ( RegistrationTokenRestServlet, ) from synapse.rest.admin.rooms import ( + BlockRoomRestServlet, + DeleteRoomStatusByDeleteIdRestServlet, + DeleteRoomStatusByRoomIdRestServlet, ForwardExtremitiesRestServlet, JoinRoomAliasServlet, ListRoomRestServlet, @@ -53,6 +57,7 @@ from synapse.rest.admin.rooms import ( RoomEventContextServlet, RoomMembersRestServlet, RoomRestServlet, + RoomRestV2Servlet, RoomStateRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet @@ -220,10 +225,14 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: Register all the admin servlets. """ register_servlets_for_client_rest_resource(hs, http_server) + BlockRoomRestServlet(hs).register(http_server) ListRoomRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server) RoomRestServlet(hs).register(http_server) + RoomRestV2Servlet(hs).register(http_server) RoomMembersRestServlet(hs).register(http_server) + DeleteRoomStatusByDeleteIdRestServlet(hs).register(http_server) + DeleteRoomStatusByRoomIdRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) @@ -253,6 +262,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: SendServerNoticeServlet(hs).register(http_server) BackgroundUpdateEnabledRestServlet(hs).register(http_server) BackgroundUpdateRestServlet(hs).register(http_server) + BackgroundUpdateStartJobRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py index 0d0183bf20..479672d4d5 100644 --- a/synapse/rest/admin/background_updates.py +++ b/synapse/rest/admin/background_updates.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) from synapse.http.site import SynapseRequest from synapse.rest.admin._base import admin_patterns, assert_user_is_admin from synapse.types import JsonDict @@ -29,37 +34,36 @@ logger = logging.getLogger(__name__) class BackgroundUpdateEnabledRestServlet(RestServlet): """Allows temporarily disabling background updates""" - PATTERNS = admin_patterns("/background_updates/enabled") + PATTERNS = admin_patterns("/background_updates/enabled$") def __init__(self, hs: "HomeServer"): - self.group_server = hs.get_groups_server_handler() - self.is_mine_id = hs.is_mine_id - self.auth = hs.get_auth() - - self.data_stores = hs.get_datastores() + self._auth = hs.get_auth() + self._data_stores = hs.get_datastores() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester.user) # We need to check that all configured databases have updates enabled. # (They *should* all be in sync.) - enabled = all(db.updates.enabled for db in self.data_stores.databases) + enabled = all(db.updates.enabled for db in self._data_stores.databases) - return 200, {"enabled": enabled} + return HTTPStatus.OK, {"enabled": enabled} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester.user) body = parse_json_object_from_request(request) enabled = body.get("enabled", True) if not isinstance(enabled, bool): - raise SynapseError(400, "'enabled' parameter must be a boolean") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "'enabled' parameter must be a boolean" + ) - for db in self.data_stores.databases: + for db in self._data_stores.databases: db.updates.enabled = enabled # If we're re-enabling them ensure that we start the background @@ -67,32 +71,29 @@ class BackgroundUpdateEnabledRestServlet(RestServlet): if enabled: db.updates.start_doing_background_updates() - return 200, {"enabled": enabled} + return HTTPStatus.OK, {"enabled": enabled} class BackgroundUpdateRestServlet(RestServlet): """Fetch information about background updates""" - PATTERNS = admin_patterns("/background_updates/status") + PATTERNS = admin_patterns("/background_updates/status$") def __init__(self, hs: "HomeServer"): - self.group_server = hs.get_groups_server_handler() - self.is_mine_id = hs.is_mine_id - self.auth = hs.get_auth() - - self.data_stores = hs.get_datastores() + self._auth = hs.get_auth() + self._data_stores = hs.get_datastores() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester.user) # We need to check that all configured databases have updates enabled. # (They *should* all be in sync.) - enabled = all(db.updates.enabled for db in self.data_stores.databases) + enabled = all(db.updates.enabled for db in self._data_stores.databases) current_updates = {} - for db in self.data_stores.databases: + for db in self._data_stores.databases: update = db.updates.get_current_update() if not update: continue @@ -104,4 +105,72 @@ class BackgroundUpdateRestServlet(RestServlet): "average_items_per_ms": update.average_items_per_ms(), } - return 200, {"enabled": enabled, "current_updates": current_updates} + return HTTPStatus.OK, {"enabled": enabled, "current_updates": current_updates} + + +class BackgroundUpdateStartJobRestServlet(RestServlet): + """Allows to start specific background updates""" + + PATTERNS = admin_patterns("/background_updates/start_job") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester.user) + + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["job_name"]) + + job_name = body["job_name"] + + if job_name == "populate_stats_process_rooms": + jobs = [ + { + "update_name": "populate_stats_process_rooms", + "progress_json": "{}", + }, + ] + elif job_name == "regenerate_directory": + jobs = [ + { + "update_name": "populate_user_directory_createtables", + "progress_json": "{}", + "depends_on": "", + }, + { + "update_name": "populate_user_directory_process_rooms", + "progress_json": "{}", + "depends_on": "populate_user_directory_createtables", + }, + { + "update_name": "populate_user_directory_process_users", + "progress_json": "{}", + "depends_on": "populate_user_directory_process_rooms", + }, + { + "update_name": "populate_user_directory_cleanup", + "progress_json": "{}", + "depends_on": "populate_user_directory_process_users", + }, + ] + else: + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid job_name") + + try: + await self._store.db_pool.simple_insert_many( + table="background_updates", + values=jobs, + desc=f"admin_api_run_{job_name}", + ) + except self._store.db_pool.engine.module.IntegrityError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Job %s is already in queue of background updates." % (job_name,), + ) + + self._store.db_pool.updates.start_doing_background_updates() + + return HTTPStatus.OK, {} diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 05c5b4bf0c..a89dda1ba5 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -13,7 +13,7 @@ # limitations under the License. import logging from http import HTTPStatus -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, cast from urllib import parse as urlparse from synapse.api.constants import EventTypes, JoinRules, Membership @@ -34,7 +34,7 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.storage.databases.main.room import RoomSortOrder -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import JsonDict, RoomID, UserID, create_requester from synapse.util import json_decoder if TYPE_CHECKING: @@ -46,6 +46,138 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class RoomRestV2Servlet(RestServlet): + """Delete a room from server asynchronously with a background task. + + It is a combination and improvement of shutdown and purge room. + + Shuts down a room by removing all local users from the room. + Blocking all future invites and joins to the room is optional. + + If desired any local aliases will be repointed to a new room + created by `new_room_user_id` and kicked users will be auto- + joined to the new room. + + If 'purge' is true, it will remove all traces of a room from the database. + """ + + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$", "v2") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + self._pagination_handler = hs.get_pagination_handler() + + async def on_DELETE( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester.user) + + content = parse_json_object_from_request(request) + + block = content.get("block", False) + if not isinstance(block, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'block' must be a boolean, if given", + Codes.BAD_JSON, + ) + + purge = content.get("purge", True) + if not isinstance(purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + + force_purge = content.get("force_purge", False) + if not isinstance(force_purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'force_purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + + if not RoomID.is_valid(room_id): + raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) + + if not await self._store.get_room(room_id): + raise NotFoundError("Unknown room id %s" % (room_id,)) + + delete_id = self._pagination_handler.start_shutdown_and_purge_room( + room_id=room_id, + new_room_user_id=content.get("new_room_user_id"), + new_room_name=content.get("room_name"), + message=content.get("message"), + requester_user_id=requester.user.to_string(), + block=block, + purge=purge, + force_purge=force_purge, + ) + + return 200, {"delete_id": delete_id} + + +class DeleteRoomStatusByRoomIdRestServlet(RestServlet): + """Get the status of the delete room background task.""" + + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete_status$", "v2") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._pagination_handler = hs.get_pagination_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + + await assert_requester_is_admin(self._auth, request) + + if not RoomID.is_valid(room_id): + raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) + + delete_ids = self._pagination_handler.get_delete_ids_by_room(room_id) + if delete_ids is None: + raise NotFoundError("No delete task for room_id '%s' found" % room_id) + + response = [] + for delete_id in delete_ids: + delete = self._pagination_handler.get_delete_status(delete_id) + if delete: + response += [ + { + "delete_id": delete_id, + **delete.asdict(), + } + ] + return 200, {"results": cast(JsonDict, response)} + + +class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): + """Get the status of the delete room background task.""" + + PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]+)$", "v2") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._pagination_handler = hs.get_pagination_handler() + + async def on_GET( + self, request: SynapseRequest, delete_id: str + ) -> Tuple[int, JsonDict]: + + await assert_requester_is_admin(self._auth, request) + + delete_status = self._pagination_handler.get_delete_status(delete_id) + if delete_status is None: + raise NotFoundError("delete id '%s' not found" % delete_id) + + return 200, cast(JsonDict, delete_status.asdict()) + + class ListRoomRestServlet(RestServlet): """ List all rooms that are known to the homeserver. Results are returned @@ -239,9 +371,22 @@ class RoomRestServlet(RestServlet): # Purge room if purge: - await pagination_handler.purge_room(room_id, force=force_purge) - - return 200, ret + try: + await pagination_handler.purge_room(room_id, force=force_purge) + except NotFoundError: + if block: + # We can block unknown rooms with this endpoint, in which case + # a failed purge is expected. + pass + else: + # But otherwise, we expect this purge to have succeeded. + raise + + # Cast safety: cast away the knowledge that this is a TypedDict. + # See https://github.com/python/mypy/issues/4976#issuecomment-579883622 + # for some discussion on why this is necessary. Either way, + # `ret` is an opaque dictionary blob as far as the rest of the app cares. + return 200, cast(JsonDict, ret) class RoomMembersRestServlet(RestServlet): @@ -303,7 +448,7 @@ class RoomStateRestServlet(RestServlet): now, # We don't bother bundling aggregations in when asked for state # events, as clients won't use them. - bundle_aggregations=False, + bundle_relations=False, ) ret = {"state": room_state} @@ -583,6 +728,7 @@ class RoomEventContextServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() + self._hs = hs self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() self._event_serializer = hs.get_event_client_serializer() @@ -600,7 +746,9 @@ class RoomEventContextServlet(RestServlet): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) + event_filter: Optional[Filter] = Filter( + self._hs, json_decoder.decode(filter_json) + ) else: event_filter = None @@ -630,7 +778,70 @@ class RoomEventContextServlet(RestServlet): results["state"], time_now, # No need to bundle aggregations for state events - bundle_aggregations=False, + bundle_relations=False, ) return 200, results + + +class BlockRoomRestServlet(RestServlet): + """ + Manage blocking of rooms. + On PUT: Add or remove a room from blocking list. + On GET: Get blocking status of room and user who has blocked this room. + """ + + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/block$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + if not RoomID.is_valid(room_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) + ) + + blocked_by = await self._store.room_is_blocked_by(room_id) + # Test `not None` if `user_id` is an empty string + # if someone add manually an entry in database + if blocked_by is not None: + response = {"block": True, "user_id": blocked_by} + else: + response = {"block": False} + + return HTTPStatus.OK, response + + async def on_PUT( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester.user) + + content = parse_json_object_from_request(request) + + if not RoomID.is_valid(room_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) + ) + + assert_params_in_dict(content, ["block"]) + block = content.get("block") + if not isinstance(block, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'block' must be a boolean.", + Codes.BAD_JSON, + ) + + if block: + await self._store.block_room(room_id, requester.user.to_string()) + else: + await self._store.unblock_room(room_id) + + return HTTPStatus.OK, {"block": block} diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index d14fafbbc9..ccd9a2a175 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -898,7 +898,7 @@ class UserTokenRestServlet(RestServlet): if auth_user.to_string() == user_id: raise SynapseError(400, "Cannot use admin API to login as self") - token = await self.auth_handler.get_access_token_for_user_id( + token = await self.auth_handler.create_access_token_for_user_id( user_id=auth_user.to_string(), device_id=None, valid_until_ms=valid_until_ms, @@ -909,7 +909,7 @@ class UserTokenRestServlet(RestServlet): class ShadowBanRestServlet(RestServlet): - """An admin API for shadow-banning a user. + """An admin API for controlling whether a user is shadow-banned. A shadow-banned users receives successful responses to their client-server API requests, but the events are not propagated into rooms. @@ -917,13 +917,21 @@ class ShadowBanRestServlet(RestServlet): Shadow-banning a user should be used as a tool of last resort and may lead to confusing or broken behaviour for the client. - Example: + Example of shadow-banning a user: POST /_synapse/admin/v1/users/@test:example.com/shadow_ban {} 200 OK {} + + Example of removing a user from being shadow-banned: + + DELETE /_synapse/admin/v1/users/@test:example.com/shadow_ban + {} + + 200 OK + {} """ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban") @@ -945,6 +953,18 @@ class ShadowBanRestServlet(RestServlet): return 200, {} + async def on_DELETE( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + + if not self.hs.is_mine_id(user_id): + raise SynapseError(400, "Only local users can be shadow-banned") + + await self.store.set_shadow_banned(UserID.from_string(user_id), False) + + return 200, {} + class RateLimitRestServlet(RestServlet): """An admin API to override ratelimiting for an user. diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py index a0971ce994..b4cb90cb76 100644 --- a/synapse/rest/client/_base.py +++ b/synapse/rest/client/_base.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) def client_patterns( path_regex: str, - releases: Iterable[int] = (0,), + releases: Iterable[str] = ("r0", "v3"), unstable: bool = True, v1: bool = False, ) -> Iterable[Pattern]: @@ -52,7 +52,7 @@ def client_patterns( v1_prefix = CLIENT_API_PREFIX + "/api/v1" patterns.append(re.compile("^" + v1_prefix + path_regex)) for release in releases: - new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,) + new_prefix = CLIENT_API_PREFIX + f"/{release}" patterns.append(re.compile("^" + new_prefix + path_regex)) return patterns diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 7281b2ee29..730c18f08f 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -262,7 +262,7 @@ class SigningKeyUploadServlet(RestServlet): } """ - PATTERNS = client_patterns("/keys/device_signing/upload$", releases=()) + PATTERNS = client_patterns("/keys/device_signing/upload$", releases=("v3",)) def __init__(self, hs: "HomeServer"): super().__init__() diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index d49a647b03..67e03dca04 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -61,7 +61,8 @@ class LoginRestServlet(RestServlet): TOKEN_TYPE = "m.login.token" JWT_TYPE = "org.matrix.login.jwt" JWT_TYPE_DEPRECATED = "m.login.jwt" - APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" + APPSERVICE_TYPE = "m.login.application_service" + APPSERVICE_TYPE_UNSTABLE = "uk.half-shot.msc2778.login.application_service" REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token" def __init__(self, hs: "HomeServer"): @@ -71,6 +72,7 @@ class LoginRestServlet(RestServlet): # JWT configuration variables. self.jwt_enabled = hs.config.jwt.jwt_enabled self.jwt_secret = hs.config.jwt.jwt_secret + self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim self.jwt_algorithm = hs.config.jwt.jwt_algorithm self.jwt_issuer = hs.config.jwt.jwt_issuer self.jwt_audiences = hs.config.jwt.jwt_audiences @@ -79,7 +81,9 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2.saml2_enabled self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = hs.config.oidc.oidc_enabled - self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None + self._msc2918_enabled = ( + hs.config.registration.refreshable_access_token_lifetime is not None + ) self.auth = hs.get_auth() @@ -143,6 +147,7 @@ class LoginRestServlet(RestServlet): flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types()) flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) + flows.append({"type": LoginRestServlet.APPSERVICE_TYPE_UNSTABLE}) return 200, {"flows": flows} @@ -159,7 +164,10 @@ class LoginRestServlet(RestServlet): should_issue_refresh_token = False try: - if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: + if login_submission["type"] in ( + LoginRestServlet.APPSERVICE_TYPE, + LoginRestServlet.APPSERVICE_TYPE_UNSTABLE, + ): appservice = self.auth.get_appservice_by_req(request) if appservice.is_rate_limited(): @@ -408,7 +416,7 @@ class LoginRestServlet(RestServlet): errcode=Codes.FORBIDDEN, ) - user = payload.get("sub", None) + user = payload.get(self.jwt_subject_claim, None) if user is None: raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) @@ -447,7 +455,9 @@ class RefreshTokenServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth_handler = hs.get_auth_handler() self._clock = hs.get_clock() - self.access_token_lifetime = hs.config.registration.access_token_lifetime + self.refreshable_access_token_lifetime = ( + hs.config.registration.refreshable_access_token_lifetime + ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: refresh_submission = parse_json_object_from_request(request) @@ -457,7 +467,9 @@ class RefreshTokenServlet(RestServlet): if not isinstance(token, str): raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM) - valid_until_ms = self._clock.time_msec() + self.access_token_lifetime + valid_until_ms = ( + self._clock.time_msec() + self.refreshable_access_token_lifetime + ) access_token, refresh_token = await self._auth_handler.refresh_token( token, valid_until_ms ) @@ -556,7 +568,7 @@ class CasTicketServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LoginRestServlet(hs).register(http_server) - if hs.config.registration.access_token_lifetime is not None: + if hs.config.registration.refreshable_access_token_lifetime is not None: RefreshTokenServlet(hs).register(http_server) SsoRedirectServlet(hs).register(http_server) if hs.config.cas.cas_enabled: diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index bf3cb34146..d2b11e39d9 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -420,7 +420,9 @@ class RegisterRestServlet(RestServlet): self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() self._registration_enabled = self.hs.config.registration.enable_registration - self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None + self._msc2918_enabled = ( + hs.config.registration.refreshable_access_token_lifetime is not None + ) self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 58f6699073..45e9f1dd90 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -224,17 +224,17 @@ class RelationPaginationServlet(RestServlet): ) now = self.clock.time_msec() - # We set bundle_aggregations to False when retrieving the original + # We set bundle_relations to False when retrieving the original # event because we want the content before relations were applied to # it. original_event = await self._event_serializer.serialize_event( - event, now, bundle_aggregations=False + event, now, bundle_relations=False ) # Similarly, we don't allow relations to be applied to relations, so we # return the original relations without any aggregations on top of them # here. serialized_events = await self._event_serializer.serialize_events( - events, now, bundle_aggregations=False + events, now, bundle_relations=False ) return_value = pagination_chunk.to_dict() @@ -298,7 +298,9 @@ class RelationAggregationPaginationServlet(RestServlet): raise SynapseError(404, "Unknown parent event.") if relation_type not in (RelationTypes.ANNOTATION, None): - raise SynapseError(400, "Relation type must be 'annotation'") + raise SynapseError( + 400, f"Relation type must be '{RelationTypes.ANNOTATION}'" + ) limit = parse_integer(request, "limit", default=5) from_token_str = parse_string(request, "from") diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 6a876cfa2f..955d4e8641 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -550,6 +550,7 @@ class RoomMessageListRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() + self._hs = hs self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -567,7 +568,9 @@ class RoomMessageListRestServlet(RestServlet): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) + event_filter: Optional[Filter] = Filter( + self._hs, json_decoder.decode(filter_json) + ) if ( event_filter and event_filter.filter_json.get("event_format", "client") @@ -672,6 +675,7 @@ class RoomEventContextServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() + self._hs = hs self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() self._event_serializer = hs.get_event_client_serializer() @@ -688,7 +692,9 @@ class RoomEventContextServlet(RestServlet): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) + event_filter: Optional[Filter] = Filter( + self._hs, json_decoder.decode(filter_json) + ) else: event_filter = None @@ -713,7 +719,7 @@ class RoomEventContextServlet(RestServlet): results["state"], time_now, # No need to bundle aggregations for state events - bundle_aggregations=False, + bundle_relations=False, ) return 200, results diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 913216a7c4..b6a2485732 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -29,7 +29,7 @@ from typing import ( from synapse.api.constants import Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError -from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection +from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.events import EventBase from synapse.events.utils import ( @@ -150,7 +150,7 @@ class SyncRestServlet(RestServlet): request_key = (user, timeout, since, filter_id, full_state, device_id) if filter_id is None: - filter_collection = DEFAULT_FILTER_COLLECTION + filter_collection = self.filtering.DEFAULT_FILTER_COLLECTION elif filter_id.startswith("{"): try: filter_object = json_decoder.decode(filter_id) @@ -160,7 +160,7 @@ class SyncRestServlet(RestServlet): except Exception: raise SynapseError(400, "Invalid filter JSON") self.filtering.check_valid_filter(filter_object) - filter_collection = FilterCollection(filter_object) + filter_collection = FilterCollection(self.hs, filter_object) else: try: filter_collection = await self.filtering.get_user_filter( @@ -522,7 +522,7 @@ class SyncRestServlet(RestServlet): time_now=time_now, # We don't bundle "live" events, as otherwise clients # will end up double counting annotations. - bundle_aggregations=False, + bundle_relations=False, token_id=token_id, event_format=event_formatter, only_event_fields=only_fields, diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 8ca97b5b18..054f3c296d 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -45,7 +45,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.oembed import OEmbedProvider -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache @@ -231,7 +231,7 @@ class PreviewUrlResource(DirectServeJsonResource): og = await make_deferred_yieldable(observable.observe()) respond_with_json_bytes(request, 200, og, send_cors=True) - async def _do_preview(self, url: str, user: str, ts: int) -> bytes: + async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes: """Check the db, and download the URL and build a preview Args: @@ -360,7 +360,7 @@ class PreviewUrlResource(DirectServeJsonResource): return jsonog.encode("utf8") - async def _download_url(self, url: str, user: str) -> MediaInfo: + async def _download_url(self, url: str, user: UserID) -> MediaInfo: # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource): ) async def _precache_image_url( - self, user: str, media_info: MediaInfo, og: JsonDict + self, user: UserID, media_info: MediaInfo, og: JsonDict ) -> None: """ Pre-cache the image (if one exists) for posterity diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 46701a8b83..5e17664b5b 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -101,8 +101,8 @@ class Thumbnailer: fits within the given rectangle:: (w_in / h_in) = (w_out / h_out) - w_out = min(w_max, h_max * (w_in / h_in)) - h_out = min(h_max, w_max * (h_in / w_in)) + w_out = max(min(w_max, h_max * (w_in / h_in)), 1) + h_out = max(min(h_max, w_max * (h_in / w_in)), 1) Args: max_width: The largest possible width. @@ -110,9 +110,9 @@ class Thumbnailer: """ if max_width * self.height < max_height * self.width: - return max_width, (max_width * self.height) // self.width + return max_width, max((max_width * self.height) // self.width, 1) else: - return (max_height * self.width) // self.height, max_height + return max((max_height * self.width) // self.height, 1), max_height def _resize(self, width: int, height: int) -> Image.Image: # 1-bit or 8-bit color palette images need converting to RGB diff --git a/synapse/server.py b/synapse/server.py index 013a7bacaa..877eba6c08 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -33,9 +33,10 @@ from typing import ( cast, ) -import twisted.internet.tcp +from twisted.internet.interfaces import IOpenSSLContextFactory +from twisted.internet.tcp import Port from twisted.web.iweb import IPolicyForHTTPS -from twisted.web.resource import IResource +from twisted.web.resource import Resource from synapse.api.auth import Auth from synapse.api.filtering import Filtering @@ -206,7 +207,7 @@ class HomeServer(metaclass=abc.ABCMeta): Attributes: config (synapse.config.homeserver.HomeserverConfig): - _listening_services (list[twisted.internet.tcp.Port]): TCP ports that + _listening_services (list[Port]): TCP ports that we are listening on to provide HTTP services. """ @@ -225,6 +226,8 @@ class HomeServer(metaclass=abc.ABCMeta): # instantiated during setup() for future return by get_datastore() DATASTORE_CLASS = abc.abstractproperty() + tls_server_context_factory: Optional[IOpenSSLContextFactory] + def __init__( self, hostname: str, @@ -247,7 +250,7 @@ class HomeServer(metaclass=abc.ABCMeta): # the key we use to sign events and requests self.signing_key = config.key.signing_key[0] self.config = config - self._listening_services: List[twisted.internet.tcp.Port] = [] + self._listening_services: List[Port] = [] self.start_time: Optional[int] = None self._instance_id = random_string(5) @@ -257,10 +260,10 @@ class HomeServer(metaclass=abc.ABCMeta): self.datastores: Optional[Databases] = None - self._module_web_resources: Dict[str, IResource] = {} + self._module_web_resources: Dict[str, Resource] = {} self._module_web_resources_consumed = False - def register_module_web_resource(self, path: str, resource: IResource): + def register_module_web_resource(self, path: str, resource: Resource): """Allows a module to register a web resource to be served at the given path. If multiple modules register a resource for the same path, the module that diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index b9a8ca997e..bc8364400d 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -82,7 +82,7 @@ class BackgroundUpdater: process and autotuning the batch size. """ - MINIMUM_BACKGROUND_BATCH_SIZE = 100 + MINIMUM_BACKGROUND_BATCH_SIZE = 1 DEFAULT_BACKGROUND_BATCH_SIZE = 100 BACKGROUND_UPDATE_INTERVAL_MS = 1000 BACKGROUND_UPDATE_DURATION_MS = 100 @@ -122,6 +122,8 @@ class BackgroundUpdater: def start_doing_background_updates(self) -> None: if self.enabled: + # if we start a new background update, not all updates are done. + self._all_done = False run_as_background_process("background_updates", self.run_background_updates) async def run_background_updates(self, sleep: bool = True) -> None: diff --git a/synapse/storage/database.py b/synapse/storage/database.py index d4cab69ebf..0693d39006 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -188,7 +188,7 @@ class LoggingDatabaseConnection: # The type of entry which goes on our after_callbacks and exception_callbacks lists. -_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]] +_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]] R = TypeVar("R") @@ -235,7 +235,7 @@ class LoggingTransaction: self.after_callbacks = after_callbacks self.exception_callbacks = exception_callbacks - def call_after(self, callback: Callable[..., None], *args: Any, **kwargs: Any): + def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any): """Call the given callback on the main twisted thread after the transaction has finished. Used to invalidate the caches on the correct thread. @@ -247,7 +247,7 @@ class LoggingTransaction: self.after_callbacks.append((callback, args, kwargs)) def call_on_exception( - self, callback: Callable[..., None], *args: Any, **kwargs: Any + self, callback: Callable[..., object], *args: Any, **kwargs: Any ): # if self.exception_callbacks is None, that means that whatever constructed the # LoggingTransaction isn't expecting there to be any callbacks; assert that diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 259cae5b37..9ff2d8d8c3 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -123,9 +123,9 @@ class DataStore( RelationsStore, CensorEventsStore, UIAuthStore, + EventForwardExtremitiesStore, CacheInvalidationWorkerStore, ServerMetricsStore, - EventForwardExtremitiesStore, LockStore, SessionStore, ): @@ -154,6 +154,7 @@ class DataStore( db_conn, "local_group_updates", "stream_id" ) + self._cache_id_gen: Optional[MultiWriterIdGenerator] if isinstance(self.database_engine, PostgresEngine): # We set the `writers` to an empty list here as we don't care about # missing updates over restarts, as we'll not have anything in our diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 2da2659f41..baec35ee27 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -412,16 +412,16 @@ class ApplicationServiceTransactionWorkerStore( ) async def set_type_stream_id_for_appservice( - self, service: ApplicationService, type: str, pos: Optional[int] + self, service: ApplicationService, stream_type: str, pos: Optional[int] ) -> None: - if type not in ("read_receipt", "presence"): + if stream_type not in ("read_receipt", "presence"): raise ValueError( "Expected type to be a valid application stream id type, got %s" - % (type,) + % (stream_type,) ) def set_type_stream_id_for_appservice_txn(txn): - stream_id_type = "%s_stream_id" % type + stream_id_type = "%s_stream_id" % stream_type txn.execute( "UPDATE application_services_state SET %s = ? WHERE as_id=?" % stream_id_type, diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index eee07227ef..0f56e10220 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -13,12 +13,12 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from synapse.events.utils import prune_event_dict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.util import json_encoder @@ -41,7 +41,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000) @wrap_as_background_process("_censor_redactions") - async def _censor_redactions(self): + async def _censor_redactions(self) -> None: """Censors all redactions older than the configured period that haven't been censored yet. @@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase and original_event.internal_metadata.is_redacted() ): # Redaction was allowed - pruned_json = json_encoder.encode( + pruned_json: Optional[str] = json_encoder.encode( prune_event_dict( original_event.room_version, original_event.get_dict() ) @@ -116,7 +116,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase updates.append((redaction_id, event_id, pruned_json)) - def _update_censor_txn(txn): + def _update_censor_txn(txn: LoggingTransaction) -> None: for redaction_id, event_id, pruned_json in updates: if pruned_json: self._censor_event_txn(txn, event_id, pruned_json) @@ -130,14 +130,16 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn) - def _censor_event_txn(self, txn, event_id, pruned_json): + def _censor_event_txn( + self, txn: LoggingTransaction, event_id: str, pruned_json: str + ) -> None: """Censor an event by replacing its JSON in the event_json table with the provided pruned JSON. Args: - txn (LoggingTransaction): The database transaction. - event_id (str): The ID of the event to censor. - pruned_json (str): The pruned JSON + txn: The database transaction. + event_id: The ID of the event to censor. + pruned_json: The pruned JSON """ self.db_pool.simple_update_one_txn( txn, @@ -157,7 +159,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase # Try to retrieve the event's content from the database or the event cache. event = await self.get_event(event_id) - def delete_expired_event_txn(txn): + def delete_expired_event_txn(txn: LoggingTransaction) -> None: # Delete the expiry timestamp associated with this event from the database. self._delete_event_expiry_txn(txn, event_id) @@ -194,14 +196,14 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase "delete_expired_event", delete_expired_event_txn ) - def _delete_event_expiry_txn(self, txn, event_id): + def _delete_event_expiry_txn(self, txn: LoggingTransaction, event_id: str) -> None: """Delete the expiry timestamp associated with an event ID without deleting the actual event. Args: - txn (LoggingTransaction): The transaction to use to perform the deletion. - event_id (str): The event ID to delete the associated expiry timestamp of. + txn: The transaction to use to perform the deletion. + event_id: The event ID to delete the associated expiry timestamp of. """ - return self.db_pool.simple_delete_txn( + self.db_pool.simple_delete_txn( txn=txn, table="event_expiry", keyvalues={"event_id": event_id} ) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 264e625bd7..ab8766c75b 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -1,4 +1,5 @@ # Copyright 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,9 +20,17 @@ from synapse.logging import issue9533_logger from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.replication.tcp.streams import ToDeviceStream from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import PostgresEngine -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache @@ -34,14 +43,21 @@ logger = logging.getLogger(__name__) class DeviceInboxWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. - self._last_device_delete_cache = ExpiringCache( + self._last_device_delete_cache: ExpiringCache[ + Tuple[str, Optional[str]], int + ] = ExpiringCache( cache_name="last_device_delete_cache", clock=self._clock, max_len=10000, @@ -53,14 +69,16 @@ class DeviceInboxWorkerStore(SQLBaseStore): self._instance_name in hs.config.worker.writers.to_device ) - self._device_inbox_id_gen = MultiWriterIdGenerator( - db_conn=db_conn, - db=database, - stream_name="to_device", - instance_name=self._instance_name, - tables=[("device_inbox", "instance_name", "stream_id")], - sequence_name="device_inbox_sequence", - writers=hs.config.worker.writers.to_device, + self._device_inbox_id_gen: AbstractStreamIdGenerator = ( + MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="to_device", + instance_name=self._instance_name, + tables=[("device_inbox", "instance_name", "stream_id")], + sequence_name="device_inbox_sequence", + writers=hs.config.worker.writers.to_device, + ) ) else: self._can_write_to_device = True @@ -101,6 +119,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == ToDeviceStream.NAME: + # If replication is happening than postgres must be being used. + assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator) self._device_inbox_id_gen.advance(instance_name, token) for row in rows: if row.entity.startswith("@"): @@ -134,7 +154,10 @@ class DeviceInboxWorkerStore(SQLBaseStore): limit: The maximum number of messages to retrieve. Returns: - A list of messages for the device and where in the stream the messages got to. + A tuple containing: + * A list of messages for the device. + * The max stream token of these messages. There may be more to retrieve + if the given limit was reached. """ has_changed = self._device_inbox_stream_cache.has_entity_changed( user_id, last_stream_id @@ -153,12 +176,19 @@ class DeviceInboxWorkerStore(SQLBaseStore): txn.execute( sql, (user_id, device_id, last_stream_id, current_stream_id, limit) ) + messages = [] + stream_pos = current_stream_id + for row in txn: stream_pos = row[0] messages.append(db_to_json(row[1])) + + # If the limit was not reached we know that there's no more data for this + # user/device pair up to current_stream_id. if len(messages) < limit: stream_pos = current_stream_id + return messages, stream_pos return await self.db_pool.runInteraction( @@ -210,11 +240,11 @@ class DeviceInboxWorkerStore(SQLBaseStore): log_kv({"message": f"deleted {count} messages for device", "count": count}) # Update the cache, ensuring that we only ever increase the value - last_deleted_stream_id = self._last_device_delete_cache.get( + updated_last_deleted_stream_id = self._last_device_delete_cache.get( (user_id, device_id), 0 ) self._last_device_delete_cache[(user_id, device_id)] = max( - last_deleted_stream_id, up_to_stream_id + updated_last_deleted_stream_id, up_to_stream_id ) return count @@ -260,13 +290,20 @@ class DeviceInboxWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (destination, last_stream_id, current_stream_id, limit)) + messages = [] + stream_pos = current_stream_id + for row in txn: stream_pos = row[0] messages.append(db_to_json(row[1])) + + # If the limit was not reached we know that there's no more data for this + # user/device pair up to current_stream_id. if len(messages) < limit: log_kv({"message": "Set stream position to current position"}) stream_pos = current_stream_id + return messages, stream_pos return await self.db_pool.runInteraction( @@ -372,8 +409,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): """Used to send messages from this server. Args: - local_messages_by_user_and_device: - Dictionary of user_id to device_id to message. + local_messages_by_user_then_device: + Dictionary of recipient user_id to recipient device_id to message. remote_messages_by_destination: Dictionary of destination server_name to the EDU JSON to send. @@ -415,7 +452,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) async with self._device_inbox_id_gen.get_next() as stream_id: - now_ms = self.clock.time_msec() + now_ms = self._clock.time_msec() await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id ) @@ -466,7 +503,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) async with self._device_inbox_id_gen.get_next() as stream_id: - now_ms = self.clock.time_msec() + now_ms = self._clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", add_messages_txn, @@ -562,6 +599,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox" REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox" + REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox" def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) @@ -577,14 +615,18 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox ) - self.db_pool.updates.register_background_update_handler( - self.REMOVE_DELETED_DEVICES, - self._remove_deleted_devices_from_device_inbox, + # Used to be a background update that deletes all device_inboxes for deleted + # devices. + self.db_pool.updates.register_noop_background_update( + self.REMOVE_DELETED_DEVICES ) + # Used to be a background update that deletes all device_inboxes for hidden + # devices. + self.db_pool.updates.register_noop_background_update(self.REMOVE_HIDDEN_DEVICES) self.db_pool.updates.register_background_update_handler( - self.REMOVE_HIDDEN_DEVICES, - self._remove_hidden_devices_from_device_inbox, + self.REMOVE_DEAD_DEVICES_FROM_INBOX, + self._remove_dead_devices_from_device_inbox, ) async def _background_drop_index_device_inbox(self, progress, batch_size): @@ -599,171 +641,83 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): return 1 - async def _remove_deleted_devices_from_device_inbox( - self, progress: JsonDict, batch_size: int + async def _remove_dead_devices_from_device_inbox( + self, + progress: JsonDict, + batch_size: int, ) -> int: - """A background update that deletes all device_inboxes for deleted devices. - - This should only need to be run once (when users upgrade to v1.47.0) + """A background update to remove devices that were either deleted or hidden from + the device_inbox table. Args: - progress: JsonDict used to store progress of this background update - batch_size: the maximum number of rows to retrieve in a single select query + progress: The update's progress dict. + batch_size: The batch size for this update. Returns: - The number of deleted rows + The number of rows deleted. """ - def _remove_deleted_devices_from_device_inbox_txn( + def _remove_dead_devices_from_device_inbox_txn( txn: LoggingTransaction, - ) -> int: - """stream_id is not unique - we need to use an inclusive `stream_id >= ?` clause, - since we might not have deleted all dead device messages for the stream_id - returned from the previous query + ) -> Tuple[int, bool]: - Then delete only rows matching the `(user_id, device_id, stream_id)` tuple, - to avoid problems of deleting a large number of rows all at once - due to a single device having lots of device messages. - """ + if "max_stream_id" in progress: + max_stream_id = progress["max_stream_id"] + else: + txn.execute("SELECT max(stream_id) FROM device_inbox") + # There's a type mismatch here between how we want to type the row and + # what fetchone says it returns, but we silence it because we know that + # res can't be None. + res: Tuple[Optional[int]] = txn.fetchone() # type: ignore[assignment] + if res[0] is None: + # this can only happen if the `device_inbox` table is empty, in which + # case we have no work to do. + return 0, True + else: + max_stream_id = res[0] - last_stream_id = progress.get("stream_id", 0) + start = progress.get("stream_id", 0) + stop = start + batch_size + # delete rows in `device_inbox` which do *not* correspond to a known, + # unhidden device. sql = """ - SELECT device_id, user_id, stream_id - FROM device_inbox + DELETE FROM device_inbox WHERE - stream_id >= ? - AND (device_id, user_id) NOT IN ( - SELECT device_id, user_id FROM devices + stream_id >= ? AND stream_id < ? + AND NOT EXISTS ( + SELECT * FROM devices d + WHERE + d.device_id=device_inbox.device_id + AND d.user_id=device_inbox.user_id + AND NOT hidden ) - ORDER BY stream_id - LIMIT ? - """ + """ - txn.execute(sql, (last_stream_id, batch_size)) - rows = txn.fetchall() - - num_deleted = 0 - for row in rows: - num_deleted += self.db_pool.simple_delete_txn( - txn, - "device_inbox", - {"device_id": row[0], "user_id": row[1], "stream_id": row[2]}, - ) - - if rows: - # send more than stream_id to progress - # otherwise it can happen in large deployments that - # no change of status is visible in the log file - # it may be that the stream_id does not change in several runs - self.db_pool.updates._background_update_progress_txn( - txn, - self.REMOVE_DELETED_DEVICES, - { - "device_id": rows[-1][0], - "user_id": rows[-1][1], - "stream_id": rows[-1][2], - }, - ) - - return num_deleted - - number_deleted = await self.db_pool.runInteraction( - "_remove_deleted_devices_from_device_inbox", - _remove_deleted_devices_from_device_inbox_txn, - ) + txn.execute(sql, (start, stop)) - # The task is finished when no more lines are deleted. - if not number_deleted: - await self.db_pool.updates._end_background_update( - self.REMOVE_DELETED_DEVICES + self.db_pool.updates._background_update_progress_txn( + txn, + self.REMOVE_DEAD_DEVICES_FROM_INBOX, + { + "stream_id": stop, + "max_stream_id": max_stream_id, + }, ) - return number_deleted - - async def _remove_hidden_devices_from_device_inbox( - self, progress: JsonDict, batch_size: int - ) -> int: - """A background update that deletes all device_inboxes for hidden devices. - - This should only need to be run once (when users upgrade to v1.47.0) - - Args: - progress: JsonDict used to store progress of this background update - batch_size: the maximum number of rows to retrieve in a single select query - - Returns: - The number of deleted rows - """ - - def _remove_hidden_devices_from_device_inbox_txn( - txn: LoggingTransaction, - ) -> int: - """stream_id is not unique - we need to use an inclusive `stream_id >= ?` clause, - since we might not have deleted all hidden device messages for the stream_id - returned from the previous query - - Then delete only rows matching the `(user_id, device_id, stream_id)` tuple, - to avoid problems of deleting a large number of rows all at once - due to a single device having lots of device messages. - """ - - last_stream_id = progress.get("stream_id", 0) - - sql = """ - SELECT device_id, user_id, stream_id - FROM device_inbox - WHERE - stream_id >= ? - AND (device_id, user_id) IN ( - SELECT device_id, user_id FROM devices WHERE hidden = ? - ) - ORDER BY stream_id - LIMIT ? - """ - - txn.execute(sql, (last_stream_id, True, batch_size)) - rows = txn.fetchall() - - num_deleted = 0 - for row in rows: - num_deleted += self.db_pool.simple_delete_txn( - txn, - "device_inbox", - {"device_id": row[0], "user_id": row[1], "stream_id": row[2]}, - ) - - if rows: - # We don't just save the `stream_id` in progress as - # otherwise it can happen in large deployments that - # no change of status is visible in the log file, as - # it may be that the stream_id does not change in several runs - self.db_pool.updates._background_update_progress_txn( - txn, - self.REMOVE_HIDDEN_DEVICES, - { - "device_id": rows[-1][0], - "user_id": rows[-1][1], - "stream_id": rows[-1][2], - }, - ) - - return num_deleted + return stop > max_stream_id - number_deleted = await self.db_pool.runInteraction( - "_remove_hidden_devices_from_device_inbox", - _remove_hidden_devices_from_device_inbox_txn, + finished = await self.db_pool.runInteraction( + "_remove_devices_from_device_inbox_txn", + _remove_dead_devices_from_device_inbox_txn, ) - # The task is finished when no more lines are deleted. - if not number_deleted: + if finished: await self.db_pool.updates._end_background_update( - self.REMOVE_HIDDEN_DEVICES + self.REMOVE_DEAD_DEVICES_FROM_INBOX, ) - return number_deleted + return batch_size class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 6daf8b8ffb..a3442814d7 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -13,17 +13,18 @@ # limitations under the License. from collections import namedtuple -from typing import Iterable, List, Optional +from typing import Iterable, List, Optional, Tuple from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore +from synapse.storage.database import LoggingTransaction +from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.types import RoomAlias from synapse.util.caches.descriptors import cached RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) -class DirectoryWorkerStore(SQLBaseStore): +class DirectoryWorkerStore(CacheInvalidationWorkerStore): async def get_association_from_room_alias( self, room_alias: RoomAlias ) -> Optional[RoomAliasMapping]: @@ -91,7 +92,7 @@ class DirectoryWorkerStore(SQLBaseStore): creator: Optional user_id of creator. """ - def alias_txn(txn): + def alias_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_insert_txn( txn, "room_aliases", @@ -126,14 +127,16 @@ class DirectoryWorkerStore(SQLBaseStore): class DirectoryStore(DirectoryWorkerStore): - async def delete_room_alias(self, room_alias: RoomAlias) -> str: + async def delete_room_alias(self, room_alias: RoomAlias) -> Optional[str]: room_id = await self.db_pool.runInteraction( "delete_room_alias", self._delete_room_alias_txn, room_alias ) return room_id - def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str: + def _delete_room_alias_txn( + self, txn: LoggingTransaction, room_alias: RoomAlias + ) -> Optional[str]: txn.execute( "SELECT room_id FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),), @@ -173,9 +176,9 @@ class DirectoryStore(DirectoryWorkerStore): If None, the creator will be left unchanged. """ - def _update_aliases_for_room_txn(txn): + def _update_aliases_for_room_txn(txn: LoggingTransaction) -> None: update_creator_sql = "" - sql_params = (new_room_id, old_room_id) + sql_params: Tuple[str, ...] = (new_room_id, old_room_id) if creator: update_creator_sql = ", creator = ?" sql_params = (new_room_id, creator, old_room_id) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index a95ac34f09..b06c1dc45b 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -408,29 +408,58 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): fallback_keys: the keys to set. This is a map from key ID (which is of the form "algorithm:id") to key data. """ + await self.db_pool.runInteraction( + "set_e2e_fallback_keys_txn", + self._set_e2e_fallback_keys_txn, + user_id, + device_id, + fallback_keys, + ) + + await self.invalidate_cache_and_stream( + "get_e2e_unused_fallback_key_types", (user_id, device_id) + ) + + def _set_e2e_fallback_keys_txn( + self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict + ) -> None: # fallback_keys will usually only have one item in it, so using a for # loop (as opposed to calling simple_upsert_many_txn) won't be too bad # FIXME: make sure that only one key per algorithm is uploaded for key_id, fallback_key in fallback_keys.items(): algorithm, key_id = key_id.split(":", 1) - await self.db_pool.simple_upsert( - "e2e_fallback_keys_json", + old_key_json = self.db_pool.simple_select_one_onecol_txn( + txn, + table="e2e_fallback_keys_json", keyvalues={ "user_id": user_id, "device_id": device_id, "algorithm": algorithm, }, - values={ - "key_id": key_id, - "key_json": json_encoder.encode(fallback_key), - "used": False, - }, - desc="set_e2e_fallback_key", + retcol="key_json", + allow_none=True, ) - await self.invalidate_cache_and_stream( - "get_e2e_unused_fallback_key_types", (user_id, device_id) - ) + new_key_json = encode_canonical_json(fallback_key).decode("utf-8") + + # If the uploaded key is the same as the current fallback key, + # don't do anything. This prevents marking the key as unused if it + # was already used. + if old_key_json != new_key_json: + self.db_pool.simple_upsert_txn( + txn, + table="e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + }, + values={ + "key_id": key_id, + "key_json": json_encoder.encode(fallback_key), + "used": False, + }, + ) @cached(max_entries=10000) async def get_e2e_unused_fallback_key_types( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 596275c23c..06832221ad 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1,6 +1,6 @@ # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -1641,8 +1641,8 @@ class PersistEventsStore: def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database.""" - def str_or_none(val: Any) -> Optional[str]: - return val if isinstance(val, str) else None + def non_null_str_or_none(val: Any) -> Optional[str]: + return val if isinstance(val, str) and "\u0000" not in val else None self.db_pool.simple_insert_many_txn( txn, @@ -1654,8 +1654,10 @@ class PersistEventsStore: "sender": event.user_id, "room_id": event.room_id, "membership": event.membership, - "display_name": str_or_none(event.content.get("displayname")), - "avatar_url": str_or_none(event.content.get("avatar_url")), + "display_name": non_null_str_or_none( + event.content.get("displayname") + ), + "avatar_url": non_null_str_or_none(event.content.get("avatar_url")), } for event in events ], @@ -1694,34 +1696,33 @@ class PersistEventsStore: }, ) - def _handle_event_relations(self, txn, event): - """Handles inserting relation data during peristence of events + def _handle_event_relations( + self, txn: LoggingTransaction, event: EventBase + ) -> None: + """Handles inserting relation data during persistence of events Args: - txn - event (EventBase) + txn: The current database transaction. + event: The event which might have relations. """ relation = event.content.get("m.relates_to") if not relation: # No relations return + # Relations must have a type and parent event ID. rel_type = relation.get("rel_type") - if rel_type not in ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.REPLACE, - RelationTypes.THREAD, - ): - # Unknown relation type + if not isinstance(rel_type, str): return parent_id = relation.get("event_id") - if not parent_id: - # Invalid relation + if not isinstance(parent_id, str): return - aggregation_key = relation.get("key") + # Annotations have a key field. + aggregation_key = None + if rel_type == RelationTypes.ANNOTATION: + aggregation_key = relation.get("key") self.db_pool.simple_insert_txn( txn, diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index ae3a8a63e4..c88fd35e7f 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -1,4 +1,4 @@ -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -171,8 +171,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): self._purged_chain_cover_index, ) + # The event_thread_relation background update was replaced with the + # event_arbitrary_relations one, which handles any relation to avoid + # needed to potentially crawl the entire events table in the future. + self.db_pool.updates.register_noop_background_update("event_thread_relation") + self.db_pool.updates.register_background_update_handler( - "event_thread_relation", self._event_thread_relation + "event_arbitrary_relations", + self._event_arbitrary_relations, ) ################################################################################ @@ -1099,23 +1105,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return result - async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int: - """Background update handler which will store thread relations for existing events.""" + async def _event_arbitrary_relations( + self, progress: JsonDict, batch_size: int + ) -> int: + """Background update handler which will store previously unknown relations for existing events.""" last_event_id = progress.get("last_event_id", "") - def _event_thread_relation_txn(txn: LoggingTransaction) -> int: + def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int: + # Fetch events and then filter based on whether the event has a + # relation or not. txn.execute( """ SELECT event_id, json FROM event_json - LEFT JOIN event_relations USING (event_id) - WHERE event_id > ? AND event_relations.event_id IS NULL + WHERE event_id > ? ORDER BY event_id LIMIT ? """, (last_event_id, batch_size), ) results = list(txn) - missing_thread_relations = [] + # (event_id, parent_id, rel_type) for each relation + relations_to_insert: List[Tuple[str, str, str]] = [] for (event_id, event_json_raw) in results: try: event_json = db_to_json(event_json_raw) @@ -1127,48 +1137,70 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) continue - # If there's no relation (or it is not a thread), skip! + # If there's no relation, skip! relates_to = event_json["content"].get("m.relates_to") if not relates_to or not isinstance(relates_to, dict): continue - if relates_to.get("rel_type") != RelationTypes.THREAD: + + # If the relation type or parent event ID is not a string, skip it. + # + # Do not consider relation types that have existed for a long time, + # since they will already be listed in the `event_relations` table. + rel_type = relates_to.get("rel_type") + if not isinstance(rel_type, str) or rel_type in ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.REPLACE, + ): continue - # Get the parent ID. parent_id = relates_to.get("event_id") if not isinstance(parent_id, str): continue - missing_thread_relations.append((event_id, parent_id)) + relations_to_insert.append((event_id, parent_id, rel_type)) + + # Insert the missing data, note that we upsert here in case the event + # has already been processed. + if relations_to_insert: + self.db_pool.simple_upsert_many_txn( + txn=txn, + table="event_relations", + key_names=("event_id",), + key_values=[(r[0],) for r in relations_to_insert], + value_names=("relates_to_id", "relation_type"), + value_values=[r[1:] for r in relations_to_insert], + ) - # Insert the missing data. - self.db_pool.simple_insert_many_txn( - txn=txn, - table="event_relations", - values=[ - { - "event_id": event_id, - "relates_to_Id": parent_id, - "relation_type": RelationTypes.THREAD, - } - for event_id, parent_id in missing_thread_relations - ], - ) + # Iterate the parent IDs and invalidate caches. + for parent_id in {r[1] for r in relations_to_insert}: + cache_tuple = (parent_id,) + self._invalidate_cache_and_stream( + txn, self.get_relations_for_event, cache_tuple + ) + self._invalidate_cache_and_stream( + txn, self.get_aggregation_groups_for_event, cache_tuple + ) + self._invalidate_cache_and_stream( + txn, self.get_thread_summary, cache_tuple + ) if results: latest_event_id = results[-1][0] self.db_pool.updates._background_update_progress_txn( - txn, "event_thread_relation", {"last_event_id": latest_event_id} + txn, "event_arbitrary_relations", {"last_event_id": latest_event_id} ) return len(results) num_rows = await self.db_pool.runInteraction( - desc="event_thread_relation", func=_event_thread_relation_txn + desc="event_arbitrary_relations", func=_event_arbitrary_relations_txn ) if not num_rows: - await self.db_pool.updates._end_background_update("event_thread_relation") + await self.db_pool.updates._end_background_update( + "event_arbitrary_relations" + ) return num_rows diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index 6d2688d711..68901b4335 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -13,15 +13,20 @@ # limitations under the License. import logging -from typing import Dict, List +from typing import Any, Dict, List from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore +from synapse.storage.database import LoggingTransaction +from synapse.storage.databases.main import CacheInvalidationWorkerStore +from synapse.storage.databases.main.event_federation import EventFederationWorkerStore logger = logging.getLogger(__name__) -class EventForwardExtremitiesStore(SQLBaseStore): +class EventForwardExtremitiesStore( + EventFederationWorkerStore, + CacheInvalidationWorkerStore, +): async def delete_forward_extremities_for_room(self, room_id: str) -> int: """Delete any extra forward extremities for a room. @@ -31,7 +36,7 @@ class EventForwardExtremitiesStore(SQLBaseStore): Returns count deleted. """ - def delete_forward_extremities_for_room_txn(txn): + def delete_forward_extremities_for_room_txn(txn: LoggingTransaction) -> int: # First we need to get the event_id to not delete sql = """ SELECT event_id FROM event_forward_extremities @@ -82,10 +87,14 @@ class EventForwardExtremitiesStore(SQLBaseStore): delete_forward_extremities_for_room_txn, ) - async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]: + async def get_forward_extremities_for_room( + self, room_id: str + ) -> List[Dict[str, Any]]: """Get list of forward extremities for a room.""" - def get_forward_extremities_for_room_txn(txn): + def get_forward_extremities_for_room_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, Any]]: sql = """ SELECT event_id, state_group, depth, received_ts FROM event_forward_extremities diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index 434986fa64..cf842803bc 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +19,7 @@ from canonicaljson import encode_canonical_json from synapse.api.errors import Codes, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -49,7 +51,7 @@ class FilteringStore(SQLBaseStore): # Need an atomic transaction to SELECT the maximal ID so far then # INSERT a new one - def _do_txn(txn): + def _do_txn(txn: LoggingTransaction) -> int: sql = ( "SELECT filter_id FROM user_filters " "WHERE user_id = ? AND filter_json = ?" @@ -61,7 +63,7 @@ class FilteringStore(SQLBaseStore): sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?" txn.execute(sql, (user_localpart,)) - max_id = txn.fetchone()[0] + max_id = txn.fetchone()[0] # type: ignore[index] if max_id is None: filter_id = 0 else: diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index 3d0df0cbd4..a540f7fb26 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -13,7 +13,7 @@ # limitations under the License. import logging from types import TracebackType -from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type +from typing import TYPE_CHECKING, Optional, Tuple, Type from weakref import WeakValueDictionary from twisted.internet.interfaces import IReactorCore @@ -62,7 +62,9 @@ class LockStore(SQLBaseStore): # A map from `(lock_name, lock_key)` to the token of any locks that we # think we currently hold. - self._live_tokens: Dict[Tuple[str, str], Lock] = WeakValueDictionary() + self._live_tokens: WeakValueDictionary[ + Tuple[str, str], Lock + ] = WeakValueDictionary() # When we shut down we want to remove the locks. Technically this can # lead to a race, as we may drop the lock while we are still processing. diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 717487be28..1b076683f7 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -13,10 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.types import JsonDict, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -46,7 +61,12 @@ class MediaSortOrder(Enum): class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -102,13 +122,15 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): self._drop_media_index_without_method, ) - async def _drop_media_index_without_method(self, progress, batch_size): + async def _drop_media_index_without_method( + self, progress: JsonDict, batch_size: int + ) -> int: """background update handler which removes the old constraints. Note that this is only run on postgres. """ - def f(txn): + def f(txn: LoggingTransaction) -> None: txn.execute( "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key" ) @@ -126,7 +148,12 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -174,7 +201,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): plus the total count of all the user's media """ - def get_local_media_by_user_paginate_txn(txn): + def get_local_media_by_user_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], int]: # Set ordering order_by_column = MediaSortOrder(order_by).value @@ -184,14 +213,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): else: order = "ASC" - args = [user_id] + args: List[Union[str, int]] = [user_id] sql = """ SELECT COUNT(*) as total_media FROM local_media_repository WHERE user_id = ? """ txn.execute(sql, args) - count = txn.fetchone()[0] + count = txn.fetchone()[0] # type: ignore[index] sql = """ SELECT @@ -268,7 +297,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) sql += sql_keep - def _get_local_media_before_txn(txn): + def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]: txn.execute(sql, (before_ts, before_ts, size_gt)) return [row[0] for row in txn] @@ -278,13 +307,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def store_local_media( self, - media_id, - media_type, - time_now_ms, - upload_name, - media_length, - user_id, - url_cache=None, + media_id: str, + media_type: str, + time_now_ms: int, + upload_name: Optional[str], + media_length: int, + user_id: UserID, + url_cache: Optional[str] = None, ) -> None: await self.db_pool.simple_insert( "local_media_repository", @@ -315,7 +344,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): None if the URL isn't cached. """ - def get_url_cache_txn(txn): + def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: # get the most recently cached result (relative to the given ts) sql = ( "SELECT response_code, etag, expires_ts, og, media_id, download_ts" @@ -359,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts - ): + ) -> None: await self.db_pool.simple_insert( "local_media_repository_url_cache", { @@ -390,13 +419,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def store_local_thumbnail( self, - media_id, - thumbnail_width, - thumbnail_height, - thumbnail_type, - thumbnail_method, - thumbnail_length, - ): + media_id: str, + thumbnail_width: int, + thumbnail_height: int, + thumbnail_type: str, + thumbnail_method: str, + thumbnail_length: int, + ) -> None: await self.db_pool.simple_upsert( table="local_media_repository_thumbnails", keyvalues={ @@ -430,14 +459,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def store_cached_remote_media( self, - origin, - media_id, - media_type, - media_length, - time_now_ms, - upload_name, - filesystem_id, - ): + origin: str, + media_id: str, + media_type: str, + media_length: int, + time_now_ms: int, + upload_name: Optional[str], + filesystem_id: str, + ) -> None: await self.db_pool.simple_insert( "remote_media_cache", { @@ -458,7 +487,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): local_media: Iterable[str], remote_media: Iterable[Tuple[str, str]], time_ms: int, - ): + ) -> None: """Updates the last access time of the given media Args: @@ -467,7 +496,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): time_ms: Current time in milliseconds """ - def update_cache_txn(txn): + def update_cache_txn(txn: LoggingTransaction) -> None: sql = ( "UPDATE remote_media_cache SET last_access_ts = ?" " WHERE media_origin = ? AND media_id = ?" @@ -488,7 +517,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media)) - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_cached_last_access_time", update_cache_txn ) @@ -542,15 +571,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def store_remote_media_thumbnail( self, - origin, - media_id, - filesystem_id, - thumbnail_width, - thumbnail_height, - thumbnail_type, - thumbnail_method, - thumbnail_length, - ): + origin: str, + media_id: str, + filesystem_id: str, + thumbnail_width: int, + thumbnail_height: int, + thumbnail_type: str, + thumbnail_method: str, + thumbnail_length: int, + ) -> None: await self.db_pool.simple_upsert( table="remote_media_cache_thumbnails", keyvalues={ @@ -566,7 +595,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_remote_media_thumbnail", ) - async def get_remote_media_before(self, before_ts): + async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]: sql = ( "SELECT media_origin, media_id, filesystem_id" " FROM remote_media_cache" @@ -602,7 +631,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " LIMIT 500" ) - def _get_expired_url_cache_txn(txn): + def _get_expired_url_cache_txn(txn: LoggingTransaction) -> List[str]: txn.execute(sql, (now_ts,)) return [row[0] for row in txn] @@ -610,18 +639,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "get_expired_url_cache", _get_expired_url_cache_txn ) - async def delete_url_cache(self, media_ids): + async def delete_url_cache(self, media_ids: Collection[str]) -> None: if len(media_ids) == 0: return sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?" - def _delete_url_cache_txn(txn): + def _delete_url_cache_txn(txn: LoggingTransaction) -> None: txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) - return await self.db_pool.runInteraction( - "delete_url_cache", _delete_url_cache_txn - ) + await self.db_pool.runInteraction("delete_url_cache", _delete_url_cache_txn) async def get_url_cache_media_before(self, before_ts: int) -> List[str]: sql = ( @@ -631,7 +658,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " LIMIT 500" ) - def _get_url_cache_media_before_txn(txn): + def _get_url_cache_media_before_txn(txn: LoggingTransaction) -> List[str]: txn.execute(sql, (before_ts,)) return [row[0] for row in txn] @@ -639,11 +666,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "get_url_cache_media_before", _get_url_cache_media_before_txn ) - async def delete_url_cache_media(self, media_ids): + async def delete_url_cache_media(self, media_ids: Collection[str]) -> None: if len(media_ids) == 0: return - def _delete_url_cache_media_txn(txn): + def _delete_url_cache_media_txn(txn: LoggingTransaction) -> None: sql = "DELETE FROM local_media_repository WHERE media_id = ?" txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) @@ -652,6 +679,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_url_cache_media", _delete_url_cache_media_txn ) diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py index 2aac64901b..a46685219f 100644 --- a/synapse/storage/databases/main/openid.py +++ b/synapse/storage/databases/main/openid.py @@ -1,6 +1,21 @@ +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Optional from synapse.storage._base import SQLBaseStore +from synapse.storage.database import LoggingTransaction class OpenIdStore(SQLBaseStore): @@ -20,7 +35,7 @@ class OpenIdStore(SQLBaseStore): async def get_user_id_for_open_id_token( self, token: str, ts_now_ms: int ) -> Optional[str]: - def get_user_id_for_token_txn(txn): + def get_user_id_for_token_txn(txn: LoggingTransaction) -> Optional[str]: sql = ( "SELECT user_id FROM open_id_tokens" " WHERE token = ? AND ? <= ts_valid_until_ms" diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index dd8e27e226..e197b7203e 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore +from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo @@ -104,7 +105,7 @@ class ProfileWorkerStore(SQLBaseStore): desc="update_remote_profile_cache", ) - async def maybe_delete_remote_profile_cache(self, user_id): + async def maybe_delete_remote_profile_cache(self, user_id: str) -> None: """Check if we still care about the remote user's profile, and if we don't then remove their profile from the cache """ @@ -116,9 +117,9 @@ class ProfileWorkerStore(SQLBaseStore): desc="delete_remote_profile_cache", ) - async def is_subscribed_remote_profile_for_user(self, user_id): + async def is_subscribed_remote_profile_for_user(self, user_id: str) -> bool: """Check whether we are interested in a remote user's profile.""" - res = await self.db_pool.simple_select_one_onecol( + res: Optional[str] = await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, retcol="user_id", @@ -139,13 +140,16 @@ class ProfileWorkerStore(SQLBaseStore): if res: return True + return False async def get_remote_profile_cache_entries_that_expire( self, last_checked: int ) -> List[Dict[str, str]]: """Get all users who haven't been checked since `last_checked`""" - def _get_remote_profile_cache_entries_that_expire_txn(txn): + def _get_remote_profile_cache_entries_that_expire_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, str]]: sql = """ SELECT user_id, displayname, avatar_url FROM remote_profile_cache diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 6c7d6ba508..0e8c168667 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -84,26 +84,26 @@ class TokenLookupResult: return self.user_id -@attr.s(frozen=True, slots=True) +@attr.s(auto_attribs=True, frozen=True, slots=True) class RefreshTokenLookupResult: """Result of looking up a refresh token.""" - user_id = attr.ib(type=str) + user_id: str """The user this token belongs to.""" - device_id = attr.ib(type=str) + device_id: str """The device associated with this refresh token.""" - token_id = attr.ib(type=int) + token_id: int """The ID of this refresh token.""" - next_token_id = attr.ib(type=Optional[int]) + next_token_id: Optional[int] """The ID of the refresh token which replaced this one.""" - has_next_refresh_token_been_refreshed = attr.ib(type=bool) + has_next_refresh_token_been_refreshed: bool """True if the next refresh token was used for another refresh.""" - has_next_access_token_been_used = attr.ib(type=bool) + has_next_access_token_been_used: bool """True if the next access token was already used at least once.""" @@ -476,7 +476,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): shadow_banned: true iff the user is to be shadow-banned, false otherwise. """ - def set_shadow_banned_txn(txn): + def set_shadow_banned_txn(txn: LoggingTransaction) -> None: user_id = user.to_string() self.db_pool.simple_update_one_txn( txn, @@ -1198,8 +1198,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): expiration_ts = now_ms + self._account_validity_period if use_delta: + assert self._account_validity_startup_job_max_delta is not None expiration_ts = random.randrange( - expiration_ts - self._account_validity_startup_job_max_delta, + int(expiration_ts - self._account_validity_startup_job_max_delta), expiration_ts, ) @@ -1728,11 +1729,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): ) self.db_pool.updates.register_background_update_handler( - "user_threepids_grandfather", self._bg_user_threepids_grandfather + "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) - self.db_pool.updates.register_background_update_handler( - "users_set_deactivated_flag", self._background_update_set_deactivated_flag + self.db_pool.updates.register_noop_background_update( + "user_threepids_grandfather" ) self.db_pool.updates.register_background_index_update( @@ -1805,35 +1806,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): return nb_processed - async def _bg_user_threepids_grandfather(self, progress, batch_size): - """We now track which identity servers a user binds their 3PID to, so - we need to handle the case of existing bindings where we didn't track - this. - - We do this by grandfathering in existing user threepids assuming that - they used one of the server configured trusted identity servers. - """ - id_servers = set(self.config.registration.trusted_third_party_id_servers) - - def _bg_user_threepids_grandfather_txn(txn): - sql = """ - INSERT INTO user_threepid_id_server - (user_id, medium, address, id_server) - SELECT user_id, medium, address, ? - FROM user_threepids - """ - - txn.execute_batch(sql, [(id_server,) for id_server in id_servers]) - - if id_servers: - await self.db_pool.runInteraction( - "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn - ) - - await self.db_pool.updates._end_background_update("user_threepids_grandfather") - - return 1 - async def set_user_deactivated_status( self, user_id: str, deactivated: bool ) -> None: diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 53576ad52f..0a43acda07 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -20,7 +20,7 @@ import attr from synapse.api.constants import RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction +from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.relations import ( AggregationPaginationToken, @@ -132,6 +132,69 @@ class RelationsWorkerStore(SQLBaseStore): "get_recent_references_for_event", _get_recent_references_for_event_txn ) + async def event_includes_relation(self, event_id: str) -> bool: + """Check if the given event relates to another event. + + An event has a relation if it has a valid m.relates_to with a rel_type + and event_id in the content: + + { + "content": { + "m.relates_to": { + "rel_type": "m.replace", + "event_id": "$other_event_id" + } + } + } + + Args: + event_id: The event to check. + + Returns: + True if the event includes a valid relation. + """ + + result = await self.db_pool.simple_select_one_onecol( + table="event_relations", + keyvalues={"event_id": event_id}, + retcol="event_id", + allow_none=True, + desc="event_includes_relation", + ) + return result is not None + + async def event_is_target_of_relation(self, parent_id: str) -> bool: + """Check if the given event is the target of another event's relation. + + An event is the target of an event relation if it has a valid + m.relates_to with a rel_type and event_id pointing to parent_id in the + content: + + { + "content": { + "m.relates_to": { + "rel_type": "m.replace", + "event_id": "$parent_id" + } + } + } + + Args: + parent_id: The event to check. + + Returns: + True if the event is the target of another event's relation. + """ + + result = await self.db_pool.simple_select_one_onecol( + table="event_relations", + keyvalues={"relates_to_id": parent_id}, + retcol="event_id", + allow_none=True, + desc="event_is_target_of_relation", + ) + return result is not None + @cached(tree=True) async def get_aggregation_groups_for_event( self, @@ -334,6 +397,62 @@ class RelationsWorkerStore(SQLBaseStore): return count, latest_event + async def events_have_relations( + self, + parent_ids: List[str], + relation_senders: Optional[List[str]], + relation_types: Optional[List[str]], + ) -> List[str]: + """Check which events have a relationship from the given senders of the + given types. + + Args: + parent_ids: The events being annotated + relation_senders: The relation senders to check. + relation_types: The relation types to check. + + Returns: + True if the event has at least one relationship from one of the given senders of the given type. + """ + # If no restrictions are given then the event has the required relations. + if not relation_senders and not relation_types: + return parent_ids + + sql = """ + SELECT relates_to_id FROM event_relations + INNER JOIN events USING (event_id) + WHERE + %s; + """ + + def _get_if_events_have_relations(txn) -> List[str]: + clauses: List[str] = [] + clause, args = make_in_list_sql_clause( + txn.database_engine, "relates_to_id", parent_ids + ) + clauses.append(clause) + + if relation_senders: + clause, temp_args = make_in_list_sql_clause( + txn.database_engine, "sender", relation_senders + ) + clauses.append(clause) + args.extend(temp_args) + if relation_types: + clause, temp_args = make_in_list_sql_clause( + txn.database_engine, "relation_type", relation_types + ) + clauses.append(clause) + args.extend(temp_args) + + txn.execute(sql % " AND ".join(clauses), args) + + return [row[0] for row in txn] + + return await self.db_pool.runInteraction( + "get_if_events_have_relations", _get_if_events_have_relations + ) + async def has_user_annotated_event( self, parent_id: str, event_type: str, aggregation_key: str, sender: str ) -> bool: diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index cefc77fa0f..7d694d852d 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -397,6 +397,20 @@ class RoomWorkerStore(SQLBaseStore): desc="is_room_blocked", ) + async def room_is_blocked_by(self, room_id: str) -> Optional[str]: + """ + Function to retrieve user who has blocked the room. + user_id is non-nullable + It returns None if the room is not blocked. + """ + return await self.db_pool.simple_select_one_onecol( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + retcol="user_id", + allow_none=True, + desc="room_is_blocked_by", + ) + async def get_rooms_paginate( self, start: int, @@ -1751,7 +1765,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) async def block_room(self, room_id: str, user_id: str) -> None: - """Marks the room as blocked. Can be called multiple times. + """Marks the room as blocked. + + Can be called multiple times (though we'll only track the last user to + block this room). + + Can be called on a room unknown to this homeserver. Args: room_id: Room to block @@ -1770,3 +1789,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): self.is_room_blocked, (room_id,), ) + + async def unblock_room(self, room_id: str) -> None: + """Remove the room from blocking list. + + Args: + room_id: Room to unblock + """ + await self.db_pool.simple_delete( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + desc="unblock_room", + ) + await self.db_pool.runInteraction( + "block_room_invalidation", + self._invalidate_cache_and_stream, + self.is_room_blocked, + (room_id,), + ) diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py index 97b2618437..39e80f6f5b 100644 --- a/synapse/storage/databases/main/room_batch.py +++ b/synapse/storage/databases/main/room_batch.py @@ -39,13 +39,11 @@ class RoomBatchStore(SQLBaseStore): async def store_state_group_id_for_event_id( self, event_id: str, state_group_id: int - ) -> Optional[str]: - { - await self.db_pool.simple_upsert( - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - values={"state_group": state_group_id, "event_id": event_id}, - # Unique constraint on event_id so we don't have to lock - lock=False, - ) - } + ) -> None: + await self.db_pool.simple_upsert( + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + values={"state_group": state_group_id, "event_id": event_id}, + # Unique constraint on event_id so we don't have to lock + lock=False, + ) diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py index ab2159c2d3..3201623fe4 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -63,12 +63,12 @@ class SignatureWorkerStore(SQLBaseStore): A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash. """ hashes = await self.get_event_reference_hashes(event_ids) - hashes = { + encoded_hashes = { e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} for e_id, h in hashes.items() } - return list(hashes.items()) + return list(encoded_hashes.items()) def _get_event_reference_hashes_txn( self, txn: Cursor, event_id: str diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index a89747d741..7f3624b128 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -16,11 +16,17 @@ import logging from typing import Any, Dict, List, Tuple from synapse.storage._base import SQLBaseStore +from synapse.storage.database import LoggingTransaction +from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) class StateDeltasStore(SQLBaseStore): + # This class must be mixed in with a child class which provides the following + # attribute. TODO: can we get static analysis to enforce this? + _curr_state_delta_stream_cache: StreamChangeCache + async def get_current_state_deltas( self, prev_stream_id: int, max_stream_id: int ) -> Tuple[int, List[Dict[str, Any]]]: @@ -60,7 +66,9 @@ class StateDeltasStore(SQLBaseStore): # max_stream_id. return max_stream_id, [] - def get_current_state_deltas_txn(txn): + def get_current_state_deltas_txn( + txn: LoggingTransaction, + ) -> Tuple[int, List[Dict[str, Any]]]: # First we calculate the max stream id that will give us less than # N results. # We arbitrarily limit to 100 stream_id entries to ensure we don't @@ -106,7 +114,9 @@ class StateDeltasStore(SQLBaseStore): "get_current_state_deltas", get_current_state_deltas_txn ) - def _get_max_stream_id_in_current_state_deltas_txn(self, txn): + def _get_max_stream_id_in_current_state_deltas_txn( + self, txn: LoggingTransaction + ) -> int: return self.db_pool.simple_select_one_onecol_txn( txn, table="current_state_delta_stream", @@ -114,7 +124,7 @@ class StateDeltasStore(SQLBaseStore): retcol="COALESCE(MAX(stream_id), -1)", ) - async def get_max_stream_id_in_current_state_deltas(self): + async def get_max_stream_id_in_current_state_deltas(self) -> int: return await self.db_pool.runInteraction( "get_max_stream_id_in_current_state_deltas", self._get_max_stream_id_in_current_state_deltas_txn, diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index dc7884b1c0..42dc807d17 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -272,31 +272,37 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: args = [] if event_filter.types: - clauses.append("(%s)" % " OR ".join("type = ?" for _ in event_filter.types)) + clauses.append( + "(%s)" % " OR ".join("event.type = ?" for _ in event_filter.types) + ) args.extend(event_filter.types) for typ in event_filter.not_types: - clauses.append("type != ?") + clauses.append("event.type != ?") args.append(typ) if event_filter.senders: - clauses.append("(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders)) + clauses.append( + "(%s)" % " OR ".join("event.sender = ?" for _ in event_filter.senders) + ) args.extend(event_filter.senders) for sender in event_filter.not_senders: - clauses.append("sender != ?") + clauses.append("event.sender != ?") args.append(sender) if event_filter.rooms: - clauses.append("(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms)) + clauses.append( + "(%s)" % " OR ".join("event.room_id = ?" for _ in event_filter.rooms) + ) args.extend(event_filter.rooms) for room_id in event_filter.not_rooms: - clauses.append("room_id != ?") + clauses.append("event.room_id != ?") args.append(room_id) if event_filter.contains_url: - clauses.append("contains_url = ?") + clauses.append("event.contains_url = ?") args.append(event_filter.contains_url) # We're only applying the "labels" filter on the database query, because applying the @@ -307,6 +313,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels)) args.extend(event_filter.labels) + # Filter on relation_senders / relation types from the joined tables. + if event_filter.relation_senders: + clauses.append( + "(%s)" + % " OR ".join( + "related_event.sender = ?" for _ in event_filter.relation_senders + ) + ) + args.extend(event_filter.relation_senders) + + if event_filter.relation_types: + clauses.append( + "(%s)" + % " OR ".join("relation_type = ?" for _ in event_filter.relation_types) + ) + args.extend(event_filter.relation_types) + return " AND ".join(clauses), args @@ -1116,7 +1139,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): bounds = generate_pagination_where_clause( direction=direction, - column_names=("topological_ordering", "stream_ordering"), + column_names=("event.topological_ordering", "event.stream_ordering"), from_token=from_bound, to_token=to_bound, engine=self.database_engine, @@ -1133,32 +1156,51 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): select_keywords = "SELECT" join_clause = "" + # Using DISTINCT in this SELECT query is quite expensive, because it + # requires the engine to sort on the entire (not limited) result set, + # i.e. the entire events table. Only use it in scenarios that could result + # in the same event ID occurring multiple times in the results. + needs_distinct = False if event_filter and event_filter.labels: # If we're not filtering on a label, then joining on event_labels will # return as many row for a single event as the number of labels it has. To # avoid this, only join if we're filtering on at least one label. - join_clause = """ + join_clause += """ LEFT JOIN event_labels USING (event_id, room_id, topological_ordering) """ if len(event_filter.labels) > 1: - # Using DISTINCT in this SELECT query is quite expensive, because it - # requires the engine to sort on the entire (not limited) result set, - # i.e. the entire events table. We only need to use it when we're - # filtering on more than two labels, because that's the only scenario - # in which we can possibly to get multiple times the same event ID in - # the results. - select_keywords += "DISTINCT" + # Multiple labels could cause the same event to appear multiple times. + needs_distinct = True + + # If there is a filter on relation_senders and relation_types join to the + # relations table. + if event_filter and ( + event_filter.relation_senders or event_filter.relation_types + ): + # Filtering by relations could cause the same event to appear multiple + # times (since there's no limit on the number of relations to an event). + needs_distinct = True + join_clause += """ + LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id) + """ + if event_filter.relation_senders: + join_clause += """ + LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) + """ + + if needs_distinct: + select_keywords += " DISTINCT" sql = """ %(select_keywords)s - event_id, instance_name, - topological_ordering, stream_ordering - FROM events + event.event_id, event.instance_name, + event.topological_ordering, event.stream_ordering + FROM events AS event %(join_clause)s - WHERE outlier = ? AND room_id = ? AND %(bounds)s - ORDER BY topological_ordering %(order)s, - stream_ordering %(order)s LIMIT ? + WHERE event.outlier = ? AND event.room_id = ? AND %(bounds)s + ORDER BY event.topological_ordering %(order)s, + event.stream_ordering %(order)s LIMIT ? """ % { "select_keywords": select_keywords, "join_clause": join_clause, diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index f93ff0a545..8f510de53d 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -1,5 +1,6 @@ # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,9 +15,10 @@ # limitations under the License. import logging -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, cast from synapse.storage._base import db_to_json +from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.types import JsonDict from synapse.util import json_encoder @@ -50,7 +52,7 @@ class TagsWorkerStore(AccountDataWorkerStore): async def get_all_updated_tags( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + ) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]: """Get updates for tags replication stream. Args: @@ -75,7 +77,9 @@ class TagsWorkerStore(AccountDataWorkerStore): if last_id == current_id: return [], current_id, False - def get_all_updated_tags_txn(txn): + def get_all_updated_tags_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str]]: sql = ( "SELECT stream_id, user_id, room_id" " FROM room_tags_revisions as r" @@ -83,13 +87,16 @@ class TagsWorkerStore(AccountDataWorkerStore): " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + # mypy doesn't understand what the query is selecting. + return cast(List[Tuple[int, str, str]], txn.fetchall()) tag_ids = await self.db_pool.runInteraction( "get_all_updated_tags", get_all_updated_tags_txn ) - def get_tag_content(txn, tag_ids): + def get_tag_content( + txn: LoggingTransaction, tag_ids + ) -> List[Tuple[int, Tuple[str, str, str]]]: sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?" results = [] for stream_id, user_id, room_id in tag_ids: @@ -127,15 +134,15 @@ class TagsWorkerStore(AccountDataWorkerStore): given version Args: - user_id(str): The user to get the tags for. - stream_id(int): The earliest update to get for the user. + user_id: The user to get the tags for. + stream_id: The earliest update to get for the user. Returns: A mapping from room_id strings to lists of tag strings for all the rooms that changed since the stream_id token. """ - def get_updated_tags_txn(txn): + def get_updated_tags_txn(txn: LoggingTransaction) -> List[str]: sql = ( "SELECT room_id from room_tags_revisions" " WHERE user_id = ? AND stream_id > ?" @@ -200,7 +207,7 @@ class TagsWorkerStore(AccountDataWorkerStore): content_json = json_encoder.encode(content) - def add_tag_txn(txn, next_id): + def add_tag_txn(txn: LoggingTransaction, next_id: int) -> None: self.db_pool.simple_upsert_txn( txn, table="room_tags", @@ -224,7 +231,7 @@ class TagsWorkerStore(AccountDataWorkerStore): """ assert self._can_write_to_account_data - def remove_tag_txn(txn, next_id): + def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None: sql = ( "DELETE FROM room_tags " " WHERE user_id = ? AND room_id = ? AND tag = ?" diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index 1ecdd40c38..f79006533f 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -14,11 +14,12 @@ from typing import Dict, Iterable -from synapse.storage._base import SQLBaseStore +from synapse.storage.database import LoggingTransaction +from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.util.caches.descriptors import cached, cachedList -class UserErasureWorkerStore(SQLBaseStore): +class UserErasureWorkerStore(CacheInvalidationWorkerStore): @cached() async def is_user_erased(self, user_id: str) -> bool: """ @@ -69,7 +70,7 @@ class UserErasureStore(UserErasureWorkerStore): user_id: full user_id to be erased """ - def f(txn): + def f(txn: LoggingTransaction) -> None: # first check if they are already in the list txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,)) if txn.fetchone(): @@ -89,7 +90,7 @@ class UserErasureStore(UserErasureWorkerStore): user_id: full user_id to be un-erased """ - def f(txn): + def f(txn: LoggingTransaction) -> None: # first check if they are already in the list txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,)) if not txn.fetchone(): diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index a1d2332326..3a00ed6835 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -45,10 +45,13 @@ Changes in SCHEMA_VERSION = 64: Changes in SCHEMA_VERSION = 65: - MSC2716: Remove unique event_id constraint from insertion_event_edges because an insertion event can have multiple edges. + - Remove unused tables `user_stats_historical` and `room_stats_historical`. """ -SCHEMA_COMPAT_VERSION = 60 # 60: "outlier" not in internal_metadata. +SCHEMA_COMPAT_VERSION = ( + 61 # 61: Remove unused tables `user_stats_historical` and `room_stats_historical` +) """Limit on how far the synapse codebase can be rolled back without breaking db compat This value is stored in the database, and checked on startup. If the value in the diff --git a/synapse/storage/schema/main/delta/65/05_remove_room_stats_historical_and_user_stats_historical.sql b/synapse/storage/schema/main/delta/65/05_remove_room_stats_historical_and_user_stats_historical.sql new file mode 100644 index 0000000000..a145180e7a --- /dev/null +++ b/synapse/storage/schema/main/delta/65/05_remove_room_stats_historical_and_user_stats_historical.sql @@ -0,0 +1,19 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + -- Remove unused tables room_stats_historical and user_stats_historical + -- which have not been read or written since schema version 61. + DROP TABLE IF EXISTS room_stats_historical; + DROP TABLE IF EXISTS user_stats_historical; \ No newline at end of file diff --git a/synapse/storage/schema/main/delta/65/02_thread_relations.sql b/synapse/storage/schema/main/delta/65/07_arbitrary_relations.sql index d60517f7b4..267b2cb539 100644 --- a/synapse/storage/schema/main/delta/65/02_thread_relations.sql +++ b/synapse/storage/schema/main/delta/65/07_arbitrary_relations.sql @@ -15,4 +15,4 @@ -- Check old events for thread relations. INSERT INTO background_updates (ordering, update_name, progress_json) VALUES - (6502, 'event_thread_relation', '{}'); + (6507, 'event_arbitrary_relations', '{}'); diff --git a/synapse/storage/schema/main/delta/65/08_device_inbox_background_updates.sql b/synapse/storage/schema/main/delta/65/08_device_inbox_background_updates.sql new file mode 100644 index 0000000000..d79455c2ce --- /dev/null +++ b/synapse/storage/schema/main/delta/65/08_device_inbox_background_updates.sql @@ -0,0 +1,18 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Background update to clear the inboxes of hidden and deleted devices. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6508, 'remove_dead_devices_from_device_inbox', '{}'); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 670811611f..ac56bc9a05 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import heapq import logging import threading @@ -87,7 +89,25 @@ def _load_current_id( return (max if step > 0 else min)(current_id, step) -class StreamIdGenerator: +class AbstractStreamIdGenerator(metaclass=abc.ABCMeta): + @abc.abstractmethod + def get_next(self) -> AsyncContextManager[int]: + raise NotImplementedError() + + @abc.abstractmethod + def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: + raise NotImplementedError() + + @abc.abstractmethod + def get_current_token(self) -> int: + raise NotImplementedError() + + @abc.abstractmethod + def get_current_token_for_writer(self, instance_name: str) -> int: + raise NotImplementedError() + + +class StreamIdGenerator(AbstractStreamIdGenerator): """Used to generate new stream ids when persisting events while keeping track of which transactions have been completed. @@ -209,7 +229,7 @@ class StreamIdGenerator: return self.get_current_token() -class MultiWriterIdGenerator: +class MultiWriterIdGenerator(AbstractStreamIdGenerator): """An ID generator that tracks a stream that can have multiple writers. Uses a Postgres sequence to coordinate ID assignment, but positions of other diff --git a/synapse/types.py b/synapse/types.py index 364ecf7d45..fb72f19343 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -19,6 +19,7 @@ from collections import namedtuple from typing import ( TYPE_CHECKING, Any, + ClassVar, Dict, Mapping, MutableMapping, @@ -38,6 +39,7 @@ from zope.interface import Interface from twisted.internet.interfaces import ( IReactorCore, IReactorPluggableNameResolver, + IReactorSSL, IReactorTCP, IReactorThreads, IReactorTime, @@ -66,6 +68,7 @@ JsonDict = Dict[str, Any] # for mypy-zope to realize it is an interface. class ISynapseReactor( IReactorTCP, + IReactorSSL, IReactorPluggableNameResolver, IReactorTime, IReactorCore, @@ -217,7 +220,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta): 'domain' : The domain part of the name """ - SIGIL: str = abc.abstractproperty() # type: ignore + SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore localpart = attr.ib(type=str) domain = attr.ib(type=str) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index abf53d149d..95f23e27b6 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -16,7 +16,7 @@ import json import logging import re import typing -from typing import Any, Callable, Dict, Generator, Pattern +from typing import Any, Callable, Dict, Generator, Optional, Pattern import attr from frozendict import frozendict @@ -110,7 +110,9 @@ class Clock: """Returns the current system time in milliseconds since epoch.""" return int(self.time() * 1000) - def looping_call(self, f: Callable, msec: float, *args, **kwargs) -> LoopingCall: + def looping_call( + self, f: Callable, msec: float, *args: Any, **kwargs: Any + ) -> LoopingCall: """Call a function repeatedly. Waits `msec` initially before calling `f` for the first time. @@ -130,20 +132,22 @@ class Clock: d.addErrback(log_failure, "Looping call died", consumeErrors=False) return call - def call_later(self, delay, callback, *args, **kwargs) -> IDelayedCall: + def call_later( + self, delay: float, callback: Callable, *args: Any, **kwargs: Any + ) -> IDelayedCall: """Call something later Note that the function will be called with no logcontext, so if it is anything other than trivial, you probably want to wrap it in run_as_background_process. Args: - delay(float): How long to wait in seconds. - callback(function): Function to call + delay: How long to wait in seconds. + callback: Function to call *args: Postional arguments to pass to function. **kwargs: Key arguments to pass to function. """ - def wrapped_callback(*args, **kwargs): + def wrapped_callback(*args: Any, **kwargs: Any) -> None: with context.PreserveLoggingContext(): callback(*args, **kwargs) @@ -158,25 +162,29 @@ class Clock: raise -def log_failure(failure, msg, consumeErrors=True): +def log_failure( + failure: Failure, msg: str, consumeErrors: bool = True +) -> Optional[Failure]: """Creates a function suitable for passing to `Deferred.addErrback` that logs any failures that occur. Args: - msg (str): Message to log - consumeErrors (bool): If true consumes the failure, otherwise passes - on down the callback chain + failure: The Failure to log + msg: Message to log + consumeErrors: If true consumes the failure, otherwise passes on down + the callback chain Returns: - func(Failure) + The Failure if consumeErrors is false. None, otherwise. """ logger.error( - msg, exc_info=(failure.type, failure.value, failure.getTracebackObject()) + msg, exc_info=(failure.type, failure.value, failure.getTracebackObject()) # type: ignore[arg-type] ) if not consumeErrors: return failure + return None def glob_to_regex(glob: str, word_boundary: bool = False) -> Pattern: diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 96efc5f3e3..20ce294209 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -27,20 +27,20 @@ from typing import ( Generic, Hashable, Iterable, + Iterator, Optional, Set, TypeVar, Union, + cast, ) import attr from typing_extensions import ContextManager from twisted.internet import defer -from twisted.internet.base import ReactorBase from twisted.internet.defer import CancelledError from twisted.internet.interfaces import IReactorTime -from twisted.python import failure from twisted.python.failure import Failure from synapse.logging.context import ( @@ -78,7 +78,7 @@ class ObservableDeferred(Generic[_T]): object.__setattr__(self, "_result", None) object.__setattr__(self, "_observers", []) - def callback(r): + def callback(r: _T) -> _T: object.__setattr__(self, "_result", (True, r)) # once we have set _result, no more entries will be added to _observers, @@ -98,7 +98,7 @@ class ObservableDeferred(Generic[_T]): ) return r - def errback(f): + def errback(f: Failure) -> Optional[Failure]: object.__setattr__(self, "_result", (False, f)) # once we have set _result, no more entries will be added to _observers, @@ -109,7 +109,7 @@ class ObservableDeferred(Generic[_T]): for observer in observers: # This is a little bit of magic to correctly propagate stack # traces when we `await` on one of the observer deferreds. - f.value.__failure__ = f + f.value.__failure__ = f # type: ignore[union-attr] try: observer.errback(f) except Exception as e: @@ -271,8 +271,7 @@ class Linearizer: if not clock: from twisted.internet import reactor - assert isinstance(reactor, ReactorBase) - clock = Clock(reactor) + clock = Clock(cast(IReactorTime, reactor)) self._clock = clock self.max_count = max_count @@ -315,7 +314,7 @@ class Linearizer: # will release the lock. @contextmanager - def _ctx_manager(_): + def _ctx_manager(_: None) -> Iterator[None]: try: yield finally: @@ -356,7 +355,7 @@ class Linearizer: new_defer = make_deferred_yieldable(defer.Deferred()) entry.deferreds[new_defer] = 1 - def cb(_r): + def cb(_r: None) -> "defer.Deferred[None]": logger.debug("Acquired linearizer lock %r for key %r", self.name, key) entry.count += 1 @@ -372,7 +371,7 @@ class Linearizer: # code must be synchronous, so this is the only sensible place.) return self._clock.sleep(0) - def eb(e): + def eb(e: Failure) -> Failure: logger.info("defer %r got err %r", new_defer, e) if isinstance(e, CancelledError): logger.debug( @@ -436,7 +435,7 @@ class ReadWriteLock: await make_deferred_yieldable(curr_writer) @contextmanager - def _ctx_manager(): + def _ctx_manager() -> Iterator[None]: try: yield finally: @@ -465,7 +464,7 @@ class ReadWriteLock: await make_deferred_yieldable(defer.gatherResults(to_wait_on)) @contextmanager - def _ctx_manager(): + def _ctx_manager() -> Iterator[None]: try: yield finally: @@ -525,7 +524,7 @@ def timeout_deferred( delayed_call = reactor.callLater(timeout, time_it_out) - def convert_cancelled(value: failure.Failure): + def convert_cancelled(value: Failure) -> Failure: # if the original deferred was cancelled, and our timeout has fired, then # the reason it was cancelled was due to our timeout. Turn the CancelledError # into a TimeoutError. @@ -535,7 +534,7 @@ def timeout_deferred( deferred.addErrback(convert_cancelled) - def cancel_timeout(result): + def cancel_timeout(result: _T) -> _T: # stop the pending call to cancel the deferred if it's been fired if delayed_call.active(): delayed_call.cancel() @@ -543,11 +542,11 @@ def timeout_deferred( deferred.addBoth(cancel_timeout) - def success_cb(val): + def success_cb(val: _T) -> None: if not new_d.called: new_d.callback(val) - def failure_cb(val): + def failure_cb(val: Failure) -> None: if not new_d.called: new_d.errback(val) @@ -558,13 +557,13 @@ def timeout_deferred( # This class can't be generic because it uses slots with attrs. # See: https://github.com/python-attrs/attrs/issues/313 -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class DoneAwaitable: # should be: Generic[R] """Simple awaitable that returns the provided value.""" - value = attr.ib(type=Any) # should be: R + value: Any # should be: R - def __await__(self): + def __await__(self) -> Any: return self def __iter__(self) -> "DoneAwaitable": diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index df4d61e4b6..15debd6c46 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -17,7 +17,7 @@ import logging import typing from enum import Enum, auto from sys import intern -from typing import Callable, Dict, Optional, Sized +from typing import Any, Callable, Dict, List, Optional, Sized import attr from prometheus_client.core import Gauge @@ -58,20 +58,20 @@ class EvictionReason(Enum): time = auto() -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class CacheMetric: - _cache = attr.ib() - _cache_type = attr.ib(type=str) - _cache_name = attr.ib(type=str) - _collect_callback = attr.ib(type=Optional[Callable]) + _cache: Sized + _cache_type: str + _cache_name: str + _collect_callback: Optional[Callable] - hits = attr.ib(default=0) - misses = attr.ib(default=0) + hits: int = 0 + misses: int = 0 eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib( factory=collections.Counter ) - memory_usage = attr.ib(default=None) + memory_usage: Optional[int] = None def inc_hits(self) -> None: self.hits += 1 @@ -89,13 +89,14 @@ class CacheMetric: self.memory_usage += memory def dec_memory_usage(self, memory: int) -> None: + assert self.memory_usage is not None self.memory_usage -= memory def clear_memory_usage(self) -> None: if self.memory_usage is not None: self.memory_usage = 0 - def describe(self): + def describe(self) -> List[str]: return [] def collect(self) -> None: @@ -118,8 +119,9 @@ class CacheMetric: self.eviction_size_by_reason[reason] ) cache_total.labels(self._cache_name).set(self.hits + self.misses) - if getattr(self._cache, "max_size", None): - cache_max_size.labels(self._cache_name).set(self._cache.max_size) + max_size = getattr(self._cache, "max_size", None) + if max_size: + cache_max_size.labels(self._cache_name).set(max_size) if TRACK_MEMORY_USAGE: # self.memory_usage can be None if nothing has been inserted @@ -193,7 +195,7 @@ KNOWN_KEYS = { } -def intern_string(string): +def intern_string(string: Optional[str]) -> Optional[str]: """Takes a (potentially) unicode string and interns it if it's ascii""" if string is None: return None @@ -204,7 +206,7 @@ def intern_string(string): return string -def intern_dict(dictionary): +def intern_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]: """Takes a dictionary and interns well known keys and their values""" return { KNOWN_KEYS.get(key, key): _intern_known_values(key, value) @@ -212,7 +214,7 @@ def intern_dict(dictionary): } -def _intern_known_values(key, value): +def _intern_known_values(key: str, value: Any) -> Any: intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key") if key in intern_keys: diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index da502aec11..3c4cc093af 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -289,7 +289,7 @@ class DeferredCache(Generic[KT, VT]): callbacks = [callback] if callback else [] self.cache.set(key, value, callbacks=callbacks) - def invalidate(self, key) -> None: + def invalidate(self, key: KT) -> None: """Delete a key, or tree of entries If the cache is backed by a regular dict, then "key" must be of diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index b9dcca17f1..375cd443f1 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -19,12 +19,15 @@ import logging from typing import ( Any, Callable, + Dict, Generic, + Hashable, Iterable, Mapping, Optional, Sequence, Tuple, + Type, TypeVar, Union, cast, @@ -32,6 +35,7 @@ from typing import ( from weakref import WeakValueDictionary from twisted.internet import defer +from twisted.python.failure import Failure from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError @@ -60,7 +64,12 @@ class _CachedFunction(Generic[F]): class _CacheDescriptorBase: - def __init__(self, orig: Callable[..., Any], num_args, cache_context=False): + def __init__( + self, + orig: Callable[..., Any], + num_args: Optional[int], + cache_context: bool = False, + ): self.orig = orig arg_spec = inspect.getfullargspec(orig) @@ -172,14 +181,14 @@ class LruCacheDescriptor(_CacheDescriptorBase): def __init__( self, - orig, + orig: Callable[..., Any], max_entries: int = 1000, cache_context: bool = False, ): super().__init__(orig, num_args=None, cache_context=cache_context) self.max_entries = max_entries - def __get__(self, obj, owner): + def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: cache: LruCache[CacheKey, Any] = LruCache( cache_name=self.orig.__name__, max_size=self.max_entries, @@ -189,7 +198,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): sentinel = LruCacheDescriptor._Sentinel.sentinel @functools.wraps(self.orig) - def _wrapped(*args, **kwargs): + def _wrapped(*args: Any, **kwargs: Any) -> Any: invalidate_callback = kwargs.pop("on_invalidate", None) callbacks = (invalidate_callback,) if invalidate_callback else () @@ -245,19 +254,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): return r1 + r2 Args: - num_args (int): number of positional arguments (excluding ``self`` and + num_args: number of positional arguments (excluding ``self`` and ``cache_context``) to use as cache keys. Defaults to all named args of the function. """ def __init__( self, - orig, - max_entries=1000, - num_args=None, - tree=False, - cache_context=False, - iterable=False, + orig: Callable[..., Any], + max_entries: int = 1000, + num_args: Optional[int] = None, + tree: bool = False, + cache_context: bool = False, + iterable: bool = False, prune_unread_entries: bool = True, ): super().__init__(orig, num_args=num_args, cache_context=cache_context) @@ -272,7 +281,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): self.iterable = iterable self.prune_unread_entries = prune_unread_entries - def __get__(self, obj, owner): + def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: cache: DeferredCache[CacheKey, Any] = DeferredCache( name=self.orig.__name__, max_entries=self.max_entries, @@ -284,7 +293,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): get_cache_key = self.cache_key_builder @functools.wraps(self.orig) - def _wrapped(*args, **kwargs): + def _wrapped(*args: Any, **kwargs: Any) -> Any: # If we're passed a cache_context then we'll want to call its invalidate() # whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) @@ -335,13 +344,19 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): of results. """ - def __init__(self, orig, cached_method_name, list_name, num_args=None): + def __init__( + self, + orig: Callable[..., Any], + cached_method_name: str, + list_name: str, + num_args: Optional[int] = None, + ): """ Args: - orig (function) - cached_method_name (str): The name of the cached method. - list_name (str): Name of the argument which is the bulk lookup list - num_args (int): number of positional arguments (excluding ``self``, + orig + cached_method_name: The name of the cached method. + list_name: Name of the argument which is the bulk lookup list + num_args: number of positional arguments (excluding ``self``, but including list_name) to use as cache keys. Defaults to all named args of the function. """ @@ -360,13 +375,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): % (self.list_name, cached_method_name) ) - def __get__(self, obj, objtype=None): + def __get__( + self, obj: Optional[Any], objtype: Optional[Type] = None + ) -> Callable[..., Any]: cached_method = getattr(obj, self.cached_method_name) cache: DeferredCache[CacheKey, Any] = cached_method.cache num_args = cached_method.num_args @functools.wraps(self.orig) - def wrapped(*args, **kwargs): + def wrapped(*args: Any, **kwargs: Any) -> Any: # If we're passed a cache_context then we'll want to call its # invalidate() whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) @@ -377,7 +394,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): results = {} - def update_results_dict(res, arg): + def update_results_dict(res: Any, arg: Hashable) -> None: results[arg] = res # list of deferreds to wait for @@ -389,13 +406,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): # otherwise a tuple is used. if num_args == 1: - def arg_to_cache_key(arg): + def arg_to_cache_key(arg: Hashable) -> Hashable: return arg else: keylist = list(keyargs) - def arg_to_cache_key(arg): + def arg_to_cache_key(arg: Hashable) -> Hashable: keylist[self.list_pos] = arg return tuple(keylist) @@ -421,7 +438,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): key = arg_to_cache_key(arg) cache.set(key, deferred, callback=invalidate_callback) - def complete_all(res): + def complete_all(res: Dict[Hashable, Any]) -> None: # the wrapped function has completed. It returns a # a dict. We can now resolve the observable deferreds in # the cache and update our own result map. @@ -430,7 +447,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): deferreds_map[e].callback(val) results[e] = val - def errback(f): + def errback(f: Failure) -> Failure: # the wrapped function has failed. Invalidate any cache # entries we're supposed to be populating, and fail # their deferreds. diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index c3f72aa06d..67ee4c693b 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -19,6 +19,8 @@ from typing import Any, Generic, Optional, TypeVar, Union, overload import attr from typing_extensions import Literal +from twisted.internet import defer + from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util import Clock @@ -81,7 +83,7 @@ class ExpiringCache(Generic[KT, VT]): # Don't bother starting the loop if things never expire return - def f(): + def f() -> "defer.Deferred[None]": return run_as_background_process( "prune_cache_%s" % self._cache_name, self._prune_cache ) @@ -157,7 +159,7 @@ class ExpiringCache(Generic[KT, VT]): self[key] = value return value - def _prune_cache(self) -> None: + async def _prune_cache(self) -> None: if not self._expiry_ms: # zero expiry time means don't expire. This should never get called # since we have this check in start too. @@ -210,7 +212,7 @@ class ExpiringCache(Generic[KT, VT]): return False -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class _CacheEntry: - time = attr.ib(type=int) - value = attr.ib() + time: int + value: Any diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 31097d6439..91837655f8 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -18,12 +18,13 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import UserID from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) -def user_left_room(distributor, user, room_id): +def user_left_room(distributor: "Distributor", user: UserID, room_id: str) -> None: distributor.fire("user_left_room", user=user, room_id=room_id) @@ -63,7 +64,7 @@ class Distributor: self.pre_registration[name] = [] self.pre_registration[name].append(observer) - def fire(self, name: str, *args, **kwargs) -> None: + def fire(self, name: str, *args: Any, **kwargs: Any) -> None: """Dispatches the given signal to the registered observers. Runs the observers as a background process. Does not return a deferred. @@ -95,7 +96,7 @@ class Signal: Each observer callable may return a Deferred.""" self.observers.append(observer) - def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]": + def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]": """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. It is not an error to fire a signal with no observers. @@ -103,7 +104,7 @@ class Signal: Returns a Deferred that will complete when all the observers have completed.""" - async def do(observer): + async def do(observer: Callable[..., Any]) -> Any: try: return await maybe_awaitable(observer(*args, **kwargs)) except Exception as e: @@ -120,5 +121,5 @@ class Signal: defer.gatherResults(deferreds, consumeErrors=True) ) - def __repr__(self): + def __repr__(self) -> str: return "<Signal name=%r>" % (self.name,) diff --git a/synapse/util/gai_resolver.py b/synapse/util/gai_resolver.py index a447ce4e55..214eb17fbc 100644 --- a/synapse/util/gai_resolver.py +++ b/synapse/util/gai_resolver.py @@ -3,23 +3,52 @@ # We copy it here as we need to instantiate `GAIResolver` manually, but it is a # private class. - from socket import ( AF_INET, AF_INET6, AF_UNSPEC, SOCK_DGRAM, SOCK_STREAM, + AddressFamily, + SocketKind, gaierror, getaddrinfo, ) +from typing import ( + TYPE_CHECKING, + Callable, + List, + NoReturn, + Optional, + Sequence, + Tuple, + Type, + Union, +) from zope.interface import implementer from twisted.internet.address import IPv4Address, IPv6Address -from twisted.internet.interfaces import IHostnameResolver, IHostResolution +from twisted.internet.interfaces import ( + IAddress, + IHostnameResolver, + IHostResolution, + IReactorThreads, + IResolutionReceiver, +) from twisted.internet.threads import deferToThreadPool +if TYPE_CHECKING: + # The types below are copied from + # https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/interfaces.py + # so that the type hints can match the interfaces. + from twisted.python.runtime import platform + + if platform.supportsThreads(): + from twisted.python.threadpool import ThreadPool + else: + ThreadPool = object # type: ignore[misc, assignment] + @implementer(IHostResolution) class HostResolution: @@ -27,13 +56,13 @@ class HostResolution: The in-progress resolution of a given hostname. """ - def __init__(self, name): + def __init__(self, name: str): """ Create a L{HostResolution} with the given name. """ self.name = name - def cancel(self): + def cancel(self) -> NoReturn: # IHostResolution.cancel raise NotImplementedError() @@ -62,6 +91,17 @@ _socktypeToType = { } +_GETADDRINFO_RESULT = List[ + Tuple[ + AddressFamily, + SocketKind, + int, + str, + Union[Tuple[str, int], Tuple[str, int, int, int]], + ] +] + + @implementer(IHostnameResolver) class GAIResolver: """ @@ -69,7 +109,12 @@ class GAIResolver: L{getaddrinfo} in a thread. """ - def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo): + def __init__( + self, + reactor: IReactorThreads, + getThreadPool: Optional[Callable[[], "ThreadPool"]] = None, + getaddrinfo: Callable[[str, int, int, int], _GETADDRINFO_RESULT] = getaddrinfo, + ): """ Create a L{GAIResolver}. @param reactor: the reactor to schedule result-delivery on @@ -89,14 +134,16 @@ class GAIResolver: ) self._getaddrinfo = getaddrinfo - def resolveHostName( + # The types on IHostnameResolver is incorrect in Twisted, see + # https://twistedmatrix.com/trac/ticket/10276 + def resolveHostName( # type: ignore[override] self, - resolutionReceiver, - hostName, - portNumber=0, - addressTypes=None, - transportSemantics="TCP", - ): + resolutionReceiver: IResolutionReceiver, + hostName: str, + portNumber: int = 0, + addressTypes: Optional[Sequence[Type[IAddress]]] = None, + transportSemantics: str = "TCP", + ) -> IHostResolution: """ See L{IHostnameResolver.resolveHostName} @param resolutionReceiver: see interface @@ -112,7 +159,7 @@ class GAIResolver: ] socketType = _transportToSocket[transportSemantics] - def get(): + def get() -> _GETADDRINFO_RESULT: try: return self._getaddrinfo( hostName, portNumber, addressFamily, socketType @@ -125,7 +172,7 @@ class GAIResolver: resolutionReceiver.resolutionBegan(resolution) @d.addCallback - def deliverResults(result): + def deliverResults(result: _GETADDRINFO_RESULT) -> None: for family, socktype, _proto, _cannoname, sockaddr in result: addrType = _afToType[family] resolutionReceiver.addressResolved( diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index b163643ca3..a0606851f7 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -92,9 +92,9 @@ def _resource_id(resource: Resource, path_seg: bytes) -> str: the mapping should looks like _resource_id(A,C) = B. Args: - resource (Resource): The *parent* Resourceb - path_seg (str): The name of the child Resource to be attached. + resource: The *parent* Resourceb + path_seg: The name of the child Resource to be attached. Returns: - str: A unique string which can be a key to the child Resource. + A unique string which can be a key to the child Resource. """ return "%s-%r" % (resource, path_seg) diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index f8b2d7bea9..48b8195ca1 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -23,7 +23,7 @@ from twisted.conch.manhole import ColoredManhole, ManholeInterpreter from twisted.conch.ssh.keys import Key from twisted.cred import checkers, portal from twisted.internet import defer -from twisted.internet.protocol import Factory +from twisted.internet.protocol import ServerFactory from synapse.config.server import ManholeConfig @@ -65,7 +65,7 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs= -----END RSA PRIVATE KEY-----""" -def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory: +def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> ServerFactory: """Starts a ssh listener with password authentication using the given username and password. Clients connecting to the ssh listener will find themselves in a colored python shell with @@ -105,7 +105,8 @@ def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory: factory.privateKeys[b"ssh-rsa"] = priv_key # type: ignore[assignment] factory.publicKeys[b"ssh-rsa"] = pub_key # type: ignore[assignment] - return factory + # ConchFactory is a Factory, not a ServerFactory, but they are identical. + return factory # type: ignore[return-value] class SynapseManhole(ColoredManhole): diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 1e784b3f1f..98ee49af6e 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -56,14 +56,22 @@ block_db_sched_duration = Counter( "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"] ) + +# This is dynamically created in InFlightGauge.__init__. +class _InFlightMetric(Protocol): + real_time_max: float + real_time_sum: float + + # Tracks the number of blocks currently active -in_flight = InFlightGauge( +in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge( "synapse_util_metrics_block_in_flight", "", labels=["block_name"], sub_metrics=["real_time_max", "real_time_sum"], ) + T = TypeVar("T", bound=Callable[..., Any]) @@ -180,7 +188,7 @@ class Measure: """ return self._logging_context.get_resource_usage() - def _update_in_flight(self, metrics) -> None: + def _update_in_flight(self, metrics: _InFlightMetric) -> None: """Gets called when processing in flight metrics""" assert self.start is not None duration = self.clock.time() - self.start |