path: root/synapse/util
Diffstat (limited to '')
-rw-r--r--   synapse/util/async_helpers.py               154
-rw-r--r--   synapse/util/caches/dictionary_cache.py      13
-rw-r--r--   synapse/util/caches/expiringcache.py          3
-rw-r--r--   synapse/util/caches/lrucache.py               3
-rw-r--r--   synapse/util/caches/response_cache.py        33
-rw-r--r--   synapse/util/caches/stream_change_cache.py   23
-rw-r--r--   synapse/util/events.py                       29
-rw-r--r--   synapse/util/iterutils.py                     7
-rw-r--r--   synapse/util/linked_list.py                   3
-rw-r--r--   synapse/util/macaroons.py                     3
-rw-r--r--   synapse/util/metrics.py                      15
-rw-r--r--   synapse/util/msisdn.py                       51
-rw-r--r--   synapse/util/patch_inline_callbacks.py        6
-rw-r--r--   synapse/util/ratelimitutils.py                2
-rw-r--r--   synapse/util/rust.py                         87
-rw-r--r--   synapse/util/stringutils.py                  12
-rw-r--r--   synapse/util/task_scheduler.py              155
-rw-r--r--   synapse/util/threepids.py                   123
-rw-r--r--   synapse/util/wheel_timer.py                   6
19 files changed, 419 insertions(+), 309 deletions(-)
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py

index 70139beef2..e596e1ed20 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -41,6 +41,7 @@ from typing import (
     Hashable,
     Iterable,
     List,
+    Literal,
     Optional,
     Set,
     Tuple,
@@ -51,7 +52,7 @@ from typing import (
 )
 
 import attr
-from typing_extensions import Concatenate, Literal, ParamSpec
+from typing_extensions import Concatenate, ParamSpec, Unpack
 
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
@@ -61,6 +62,7 @@ from twisted.python.failure import Failure
 from synapse.logging.context import (
     PreserveLoggingContext,
     make_deferred_yieldable,
+    run_coroutine_in_background,
     run_in_background,
 )
 from synapse.util import Clock
@@ -344,6 +346,7 @@ T1 = TypeVar("T1")
 T2 = TypeVar("T2")
 T3 = TypeVar("T3")
 T4 = TypeVar("T4")
+T5 = TypeVar("T5")
 
 
 @overload
@@ -402,6 +405,112 @@ def gather_results(  # type: ignore[misc]
     return deferred.addCallback(tuple)
 
 
+@overload
+async def gather_optional_coroutines(
+    *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]]]],
+) -> Tuple[Optional[T1]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+    *coroutines: Unpack[
+        Tuple[
+            Optional[Coroutine[Any, Any, T1]],
+            Optional[Coroutine[Any, Any, T2]],
+        ]
+    ],
+) -> Tuple[Optional[T1], Optional[T2]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+    *coroutines: Unpack[
+        Tuple[
+            Optional[Coroutine[Any, Any, T1]],
+            Optional[Coroutine[Any, Any, T2]],
+            Optional[Coroutine[Any, Any, T3]],
+        ]
+    ],
+) -> Tuple[Optional[T1], Optional[T2], Optional[T3]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+    *coroutines: Unpack[
+        Tuple[
+            Optional[Coroutine[Any, Any, T1]],
+            Optional[Coroutine[Any, Any, T2]],
+            Optional[Coroutine[Any, Any, T3]],
+            Optional[Coroutine[Any, Any, T4]],
+        ]
+    ],
+) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+    *coroutines: Unpack[
+        Tuple[
+            Optional[Coroutine[Any, Any, T1]],
+            Optional[Coroutine[Any, Any, T2]],
+            Optional[Coroutine[Any, Any, T3]],
+            Optional[Coroutine[Any, Any, T4]],
+            Optional[Coroutine[Any, Any, T5]],
+        ]
+    ],
+) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5]]: ...
+
+
+async def gather_optional_coroutines(
+    *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]],
+) -> Tuple[Optional[T1], ...]:
+    """Helper function that allows waiting on multiple coroutines at once.
+
+    The return value is a tuple of the return values of the coroutines in order.
+
+    If a `None` is passed instead of a coroutine, it will be ignored and a None
+    is returned in the tuple.
+
+    Note: For typechecking we need to have an explicit overload for each
+    distinct number of coroutines passed in. If you see type problems, it's
+    likely because you're using many arguments and you need to add a new
+    overload above.
+    """
+
+    try:
+        results = await make_deferred_yieldable(
+            defer.gatherResults(
+                [
+                    run_coroutine_in_background(coroutine)
+                    for coroutine in coroutines
+                    if coroutine is not None
+                ],
+                consumeErrors=True,
+            )
+        )
+
+        results_iter = iter(results)
+        return tuple(
+            next(results_iter) if coroutine is not None else None
+            for coroutine in coroutines
+        )
+    except defer.FirstError as dfe:
+        # unwrap the error from defer.gatherResults.
+
+        # The raised exception's traceback only includes func() etc if
+        # the 'await' happens before the exception is thrown - ie if the failure
+        # happens *asynchronously* - otherwise Twisted throws away the traceback as it
+        # could be large.
+        #
+        # We could maybe reconstruct a fake traceback from Failure.frames.
+        # Or maybe we could throw Twisted into the fires of Mordor.
+
+        # suppress exception chaining, because the FirstError doesn't tell us anything
+        # very interesting.
+        assert isinstance(dfe.subFailure.value, BaseException)
+        raise dfe.subFailure.value from None
+
+
 @attr.s(slots=True, auto_attribs=True)
 class _LinearizerEntry:
     # The number of things executing.
@@ -885,3 +994,46 @@ class AwakenableSleeper:
             # Cancel the sleep if we were woken up
             if call.active():
                 call.cancel()
+
+
+class DeferredEvent:
+    """Like threading.Event but for async code"""
+
+    def __init__(self, reactor: IReactorTime) -> None:
+        self._reactor = reactor
+        self._deferred: "defer.Deferred[None]" = defer.Deferred()
+
+    def set(self) -> None:
+        if not self._deferred.called:
+            self._deferred.callback(None)
+
+    def clear(self) -> None:
+        if self._deferred.called:
+            self._deferred = defer.Deferred()
+
+    def is_set(self) -> bool:
+        return self._deferred.called
+
+    async def wait(self, timeout_seconds: float) -> bool:
+        if self.is_set():
+            return True
+
+        # Create a deferred that gets called in N seconds
+        sleep_deferred: "defer.Deferred[None]" = defer.Deferred()
+        call = self._reactor.callLater(timeout_seconds, sleep_deferred.callback, None)
+
+        try:
+            await make_deferred_yieldable(
+                defer.DeferredList(
+                    [sleep_deferred, self._deferred],
+                    fireOnOneCallback=True,
+                    fireOnOneErrback=True,
+                    consumeErrors=True,
+                )
+            )
+        finally:
+            # Cancel the sleep if we were woken up
+            if call.active():
+                call.cancel()
+
+        return self.is_set()
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 1e6696332f..14bd3ba3b0 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -21,10 +21,19 @@
 import enum
 import logging
 import threading
-from typing import Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
+from typing import (
+    Dict,
+    Generic,
+    Iterable,
+    Literal,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 import attr
-from typing_extensions import Literal
 
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 8017c031ee..3198fdd2ed 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -21,10 +21,9 @@
 import logging
 from collections import OrderedDict
 
-from typing import Any, Generic, Iterable, Optional, TypeVar, Union, overload
+from typing import Any, Generic, Iterable, Literal, Optional, TypeVar, Union, overload
 
 import attr
-from typing_extensions import Literal
 
 from twisted.internet import defer
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 481a1a621e..2e5efa3a52 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -34,6 +34,7 @@ from typing import (
     Generic,
     Iterable,
     List,
+    Literal,
     Optional,
     Set,
     Tuple,
@@ -44,8 +45,6 @@ from typing import (
     overload,
 )
 
-from typing_extensions import Literal
-
 from twisted.internet import reactor
 from twisted.internet.interfaces import IReactorTime
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 96b7ca83dc..54b99134b9 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -101,7 +101,13 @@ class ResponseCache(Generic[KV]):
         used rather than trying to compute a new response.
     """
 
-    def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
+    def __init__(
+        self,
+        clock: Clock,
+        name: str,
+        timeout_ms: float = 0,
+        enable_logging: bool = True,
+    ):
         self._result_cache: Dict[KV, ResponseCacheEntry] = {}
 
         self.clock = clock
@@ -109,6 +115,7 @@ class ResponseCache(Generic[KV]):
 
         self._name = name
         self._metrics = register_cache("response_cache", name, self, resizable=False)
+        self._enable_logging = enable_logging
 
     def size(self) -> int:
         return len(self._result_cache)
@@ -246,9 +253,12 @@ class ResponseCache(Generic[KV]):
         """
         entry = self._get(key)
         if not entry:
-            logger.debug(
-                "[%s]: no cached result for [%s], calculating new one", self._name, key
-            )
+            if self._enable_logging:
+                logger.debug(
+                    "[%s]: no cached result for [%s], calculating new one",
+                    self._name,
+                    key,
+                )
             context = ResponseCacheContext(cache_key=key)
             if cache_context:
                 kwargs["cache_context"] = context
@@ -269,12 +279,15 @@ class ResponseCache(Generic[KV]):
             return await make_deferred_yieldable(entry.result.observe())
 
         result = entry.result.observe()
-        if result.called:
-            logger.info("[%s]: using completed cached result for [%s]", self._name, key)
-        else:
-            logger.info(
-                "[%s]: using incomplete cached result for [%s]", self._name, key
-            )
+        if self._enable_logging:
+            if result.called:
+                logger.info(
+                    "[%s]: using completed cached result for [%s]", self._name, key
+                )
+            else:
+                logger.info(
+                    "[%s]: using incomplete cached result for [%s]", self._name, key
+                )
 
         span_context = entry.opentracing_span_context
         with start_active_span_follows_from(
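A short sketch of the new `enable_logging` flag (the cache name and handler are hypothetical, and `hs` is assumed to be the running HomeServer); hot paths can now deduplicate requests without emitting a log line per hit:

    from synapse.util.caches.response_cache import ResponseCache

    cache: "ResponseCache[str]" = ResponseCache(
        hs.get_clock(),
        "hypothetical_hot_path",
        timeout_ms=30_000,
        enable_logging=False,  # suppress the per-request debug/info lines
    )

    async def _expensive_lookup(key: str) -> dict: ...

    async def handle_request(key: str) -> dict:
        # Concurrent callers with the same key share a single computation.
        return await cache.wrap(key, _expensive_lookup, key)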
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 16fcb00206..5ac8643eef 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -142,9 +142,9 @@ class StreamChangeCache:
         """
         assert isinstance(stream_pos, int)
 
-        # _cache is not valid at or before the earliest known stream position, so
+        # _cache is not valid before the earliest known stream position, so
         # return that the entity has changed.
-        if stream_pos <= self._earliest_known_stream_pos:
+        if stream_pos < self._earliest_known_stream_pos:
             self.metrics.inc_misses()
             return True
 
@@ -186,7 +186,7 @@ class StreamChangeCache:
         This will be all entities if the given stream position is at or earlier
         than the earliest known stream position.
         """
-        if not self._cache or stream_pos <= self._earliest_known_stream_pos:
+        if not self._cache or stream_pos < self._earliest_known_stream_pos:
             self.metrics.inc_misses()
             return set(entities)
 
@@ -238,9 +238,9 @@ class StreamChangeCache:
         """
         assert isinstance(stream_pos, int)
 
-        # _cache is not valid at or before the earliest known stream position, so
+        # _cache is not valid before the earliest known stream position, so
         # return that an entity has changed.
-        if stream_pos <= self._earliest_known_stream_pos:
+        if stream_pos < self._earliest_known_stream_pos:
             self.metrics.inc_misses()
             return True
 
@@ -270,9 +270,9 @@ class StreamChangeCache:
         """
         assert isinstance(stream_pos, int)
 
-        # _cache is not valid at or before the earliest known stream position, so
+        # _cache is not valid before the earliest known stream position, so
         # return None to mark that it is unknown if an entity has changed.
-        if stream_pos <= self._earliest_known_stream_pos:
+        if stream_pos < self._earliest_known_stream_pos:
             return AllEntitiesChangedResult(None)
 
         changed_entities: List[EntityType] = []
@@ -314,6 +314,15 @@ class StreamChangeCache:
             self._entity_to_key[entity] = stream_pos
             self._evict()
 
+    def all_entities_changed(self, stream_pos: int) -> None:
+        """
+        Mark all entities as changed. This is useful when the cache is invalidated and
+        there may be some potential change for all of the entities.
+        """
+        self._cache.clear()
+        self._entity_to_key.clear()
+        self._earliest_known_stream_pos = stream_pos
+
     def _evict(self) -> None:
         """
         Ensure the cache has not exceeded the maximum size.
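A sketch of the boundary change: a query at exactly the earliest known stream position is now answered from the cache, and only strictly older positions force the pessimistic "changed" answer (the entity IDs here are illustrative):

    from synapse.util.caches.stream_change_cache import StreamChangeCache

    cache = StreamChangeCache("ExampleCache", current_stream_pos=5)
    cache.entity_has_changed("@user:example.org", 7)

    cache.has_entity_changed("@other:example.org", 5)  # False: pos 5 is now covered
    cache.has_entity_changed("@other:example.org", 4)  # True: before earliest known
    cache.has_entity_changed("@user:example.org", 6)   # True: changed at pos 7

    # The new bulk helper wipes the cache and resets the earliest known position:
    cache.all_entities_changed(stream_pos=10)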
diff --git a/synapse/util/events.py b/synapse/util/events.py
new file mode 100644
index 0000000000..ad9b946578
--- /dev/null
+++ b/synapse/util/events.py
@@ -0,0 +1,29 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+#
+
+from synapse.util.stringutils import random_string
+
+
+def generate_fake_event_id() -> str:
+    """
+    Generate an event ID from random ASCII characters.
+
+    This is primarily useful for generating fake event IDs in response to
+    requests from shadow-banned users.
+
+    Returns:
+        A string intended to look like an event ID, but with no actual meaning.
+    """
+    return "$" + random_string(43)
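For illustration, the output is just a "$" followed by 43 random ASCII letters, shaped like an event ID (the value shown is made up):

    from synapse.util.events import generate_fake_event_id

    fake_id = generate_fake_event_id()
    # e.g. "$XrqFbNwzLMvuRxSkPaDeTgHjKcWyBmQnZoIlEsUfAvC" - looks like an
    # event ID but refers to no real event.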
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index b73f690b88..0a6a30aab2 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -30,14 +30,13 @@ from typing import (
     Iterator,
     List,
     Mapping,
+    Protocol,
     Set,
     Sized,
     Tuple,
     TypeVar,
 )
 
-from typing_extensions import Protocol
-
 T = TypeVar("T")
 S = TypeVar("S", bound="_SelfSlice")
@@ -115,7 +114,7 @@ def sorted_topologically(
 
     # This is implemented by Kahn's algorithm.
 
-    degree_map = {node: 0 for node in nodes}
+    degree_map = dict.fromkeys(nodes, 0)
     reverse_graph: Dict[T, Set[T]] = {}
 
     for node, edges in graph.items():
@@ -165,7 +164,7 @@ def sorted_topologically_batched(
         persisted.
     """
 
-    degree_map = {node: 0 for node in nodes}
+    degree_map = dict.fromkeys(nodes, 0)
     reverse_graph: Dict[T, Set[T]] = {}
 
     for node, edges in graph.items():
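A quick note on the `dict.fromkeys` swap above: for an immutable initial value like 0 the two forms are equivalent, and `fromkeys` avoids a Python-level comprehension loop:

    nodes = ["c", "a", "b"]
    assert dict.fromkeys(nodes, 0) == {node: 0 for node in nodes}

The one value object is shared across all keys, so this is only safe for immutable defaults such as the 0 used here; a mutable default like [] would be aliased between keys.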
diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py
index e9a5fff211..87f801c0cf 100644
--- a/synapse/util/linked_list.py
+++ b/synapse/util/linked_list.py
@@ -19,8 +19,7 @@
 #
 #
 
-"""A circular doubly linked list implementation.
-"""
+"""A circular doubly linked list implementation."""
 
 import threading
 from typing import Generic, Optional, Type, TypeVar
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
index 84ae226207..6fa15543ec 100644
--- a/synapse/util/macaroons.py
+++ b/synapse/util/macaroons.py
@@ -22,12 +22,11 @@
 
 """Utilities for manipulating macaroons"""
 
-from typing import Callable, Optional
+from typing import Callable, Literal, Optional
 
 import attr
 import pymacaroons
 from pymacaroons.exceptions import MacaroonVerificationFailedException
-from typing_extensions import Literal
 
 from synapse.util import Clock, stringutils
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 517e79ce5f..6a389f7a7e 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -22,10 +22,19 @@
 import logging
 from functools import wraps
 from types import TracebackType
-from typing import Awaitable, Callable, Dict, Generator, Optional, Type, TypeVar
+from typing import (
+    Awaitable,
+    Callable,
+    Dict,
+    Generator,
+    Optional,
+    Protocol,
+    Type,
+    TypeVar,
+)
 
 from prometheus_client import CollectorRegistry, Counter, Metric
-from typing_extensions import Concatenate, ParamSpec, Protocol
+from typing_extensions import Concatenate, ParamSpec
 
 from synapse.logging.context import (
     ContextResourceUsage,
@@ -110,7 +119,7 @@ def measure_func(
     """
 
     def wrapper(
-        func: Callable[Concatenate[HasClock, P], Awaitable[R]]
+        func: Callable[Concatenate[HasClock, P], Awaitable[R]],
     ) -> Callable[P, Awaitable[R]]:
         block_name = func.__name__ if name is None else name
diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py
deleted file mode 100644
index b6a784f0bc..0000000000
--- a/synapse/util/msisdn.py
+++ /dev/null
@@ -1,51 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2017 Vector Creations Ltd
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-
-import phonenumbers
-
-from synapse.api.errors import SynapseError
-
-
-def phone_number_to_msisdn(country: str, number: str) -> str:
-    """
-    Takes an ISO-3166-1 2 letter country code and phone number and
-    returns an msisdn representing the canonical version of that
-    phone number.
-
-    As an example, if `country` is "GB" and `number` is "7470674927", this
-    function will return "447470674927".
-
-    Args:
-        country: ISO-3166-1 2 letter country code
-        number: Phone number in a national or international format
-
-    Returns:
-        The canonical form of the phone number, as an msisdn.
-    Raises:
-        SynapseError if the number could not be parsed.
-    """
-    try:
-        phoneNumber = phonenumbers.parse(number, country)
-    except phonenumbers.NumberParseException:
-        raise SynapseError(400, "Unable to parse phone number")
-    return phonenumbers.format_number(phoneNumber, phonenumbers.PhoneNumberFormat.E164)[
-        1:
-    ]
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 46dad32156..beea4d2888 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -50,7 +50,7 @@ def do_patch() -> None:
         return
 
     def new_inline_callbacks(
-        f: Callable[P, Generator["Deferred[object]", object, T]]
+        f: Callable[P, Generator["Deferred[object]", object, T]],
     ) -> Callable[P, "Deferred[T]"]:
         @functools.wraps(f)
         def wrapped(*args: P.args, **kwargs: P.kwargs) -> "Deferred[T]":
@@ -162,7 +162,7 @@ def _check_yield_points(
                     d = result.throwExceptionIntoGenerator(gen)
                 else:
                     d = gen.send(result)
-            except (StopIteration, defer._DefGen_Return) as e:
+            except StopIteration as e:
                 if current_context() != expected_context:
                     # This happens when the context is lost sometime *after* the
                     # final yield and returning. E.g. we forgot to yield on a
@@ -183,7 +183,7 @@ def _check_yield_points(
                         )
                     )
                     changes.append(err)
-                # The `StopIteration` or `_DefGen_Return` contains the return value from the
+                # The `StopIteration` contains the return value from the
                 # generator.
                 return cast(T, e.value)
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 8ead72bb7a..3f067b792c 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -103,7 +103,7 @@ _rate_limiter_instances_lock = threading.Lock()
 
 def _get_counts_from_rate_limiter_instance(
-    count_func: Callable[["FederationRateLimiter"], int]
+    count_func: Callable[["FederationRateLimiter"], int],
 ) -> Mapping[Tuple[str, ...], int]:
     """Returns a count of something (slept/rejected hosts) by (metrics_name)"""
     # Cast to a list to prevent it changing while the Prometheus
diff --git a/synapse/util/rust.py b/synapse/util/rust.py
index 0e35d6d188..37f43459f1 100644
--- a/synapse/util/rust.py
+++ b/synapse/util/rust.py
@@ -19,9 +19,12 @@
 #
 #
 
+import json
 import os
-import sys
+import urllib.parse
 from hashlib import blake2b
+from importlib.metadata import Distribution, PackageNotFoundError
+from typing import Optional
 
 import synapse
 from synapse.synapse_rust import get_rust_file_digest
@@ -32,22 +35,17 @@ def check_rust_lib_up_to_date() -> None:
     be rebuilt.
     """
 
-    if not _dist_is_editable():
-        return
-
-    synapse_dir = os.path.dirname(synapse.__file__)
-    synapse_root = os.path.abspath(os.path.join(synapse_dir, ".."))
-
-    # Double check we've not gone into site-packages...
-    if os.path.basename(synapse_root) == "site-packages":
-        return
-
-    # ... and it looks like the root of a python project.
-    if not os.path.exists("pyproject.toml"):
-        return
+    # Get the location of the editable install.
+    synapse_root = get_synapse_source_directory()
+    if synapse_root is None:
+        return None
 
     # Get the hash of all Rust source files
-    hash = _hash_rust_files_in_directory(os.path.join(synapse_root, "rust", "src"))
+    rust_path = os.path.join(synapse_root, "rust", "src")
+    if not os.path.exists(rust_path):
+        return None
+
+    hash = _hash_rust_files_in_directory(rust_path)
 
     if hash != get_rust_file_digest():
         raise Exception("Rust module outdated. Please rebuild using `poetry install`")
@@ -82,10 +80,55 @@ def _hash_rust_files_in_directory(directory: str) -> str:
 
     return hasher.hexdigest()
 
-def _dist_is_editable() -> bool:
-    """Is distribution an editable install?"""
-    for path_item in sys.path:
-        egg_link = os.path.join(path_item, "matrix-synapse.egg-link")
-        if os.path.isfile(egg_link):
-            return True
-    return False
+def get_synapse_source_directory() -> Optional[str]:
+    """Try and find the source directory of synapse for editable installs (like
+    those used in development).
+
+    Returns None if not an editable install (or otherwise can't find the source
+    directory).
+    """
+
+    # Try and find the installed matrix-synapse package.
+    try:
+        package = Distribution.from_name("matrix-synapse")
+    except PackageNotFoundError:
+        # The package is not found, so it's not installed and so must be being
+        # pulled out from a local directory (usually the current one).
+        synapse_dir = os.path.dirname(synapse.__file__)
+        synapse_root = os.path.abspath(os.path.join(synapse_dir, ".."))
+
+        # Double check we've not gone into site-packages...
+        if os.path.basename(synapse_root) == "site-packages":
+            return None
+
+        # ... and it looks like the root of a python project.
+        if not os.path.exists("pyproject.toml"):
+            return None
+
+        return synapse_root
+
+    # Read the `direct_url.json` metadata for the package. This won't exist for
+    # packages installed via a repository/etc.
+    # c.f. https://packaging.python.org/en/latest/specifications/direct-url/
+    direct_url_json = package.read_text("direct_url.json")
+    if direct_url_json is None:
+        return None
+
+    # c.f. https://packaging.python.org/en/latest/specifications/direct-url/ for
+    # the format
+    direct_url_dict: dict = json.loads(direct_url_json)
+
+    # `url` must exist as a key, and point to where we fetched the repo from.
+    project_url = urllib.parse.urlparse(direct_url_dict["url"])
+
+    # If it's not a local file then we must have built the rust libs either a)
+    # after we downloaded the package, or b) when the downloaded wheel was built.
+    if project_url.scheme != "file":
+        return None
+
+    # And finally, if it's not an editable install then the files can't have
+    # changed since we installed the package.
+    if not direct_url_dict.get("dir_info", {}).get("editable", False):
+        return None
+
+    return project_url.path
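For reference, a minimal sketch of the PEP 610 metadata the new helper reads (assumes a locally installed matrix-synapse; the path shown is hypothetical):

    from importlib.metadata import Distribution

    package = Distribution.from_name("matrix-synapse")
    print(package.read_text("direct_url.json"))
    # For an editable install this is typically along the lines of:
    #   {"url": "file:///home/dev/synapse", "dir_info": {"editable": true}}
    # which is what get_synapse_source_directory() checks before hashing rust/src.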
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 13ff54b669..32b5bc00c9 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -43,6 +43,14 @@ CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
 #
 MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
 
+# https://spec.matrix.org/v1.13/appendices/#common-namespaced-identifier-grammar
+#
+# At least one character, less than or equal to 255 characters. Must start with
+# a-z; the rest is a-z, 0-9, '-', '_', or '.'.
+#
+# This doesn't check anything about validity of namespaces.
+NAMESPACED_GRAMMAR = re.compile(r"^[a-z][a-z0-9_.-]{0,254}$")
+
 
 def random_string(length: int) -> str:
     """Generate a cryptographically secure string of random letters.
@@ -68,6 +76,10 @@ def is_ascii(s: bytes) -> bool:
     return True
 
 
+def is_namedspaced_grammar(s: str) -> bool:
+    return bool(NAMESPACED_GRAMMAR.match(s))
+
+
 def assert_valid_client_secret(client_secret: str) -> None:
     """Validate that a given string matches the client_secret defined by the spec"""
     if (
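A small sketch of what the new grammar check accepts (the sample identifiers are illustrative, not from the diff):

    from synapse.util.stringutils import is_namedspaced_grammar

    assert is_namedspaced_grammar("m.push_rules")        # lowercase with dots
    assert is_namedspaced_grammar("org.example.custom")  # reverse-DNS style
    assert not is_namedspaced_grammar("1bad")            # must start with a-z
    assert not is_namedspaced_grammar("Upper.Case")      # uppercase rejected
    assert not is_namedspaced_grammar("a" * 256)         # longer than 255 chars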
diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
index 448960b297..4683d09cd7 100644
--- a/synapse/util/task_scheduler.py
+++ b/synapse/util/task_scheduler.py
@@ -46,33 +46,43 @@ logger = logging.getLogger(__name__)
 
 class TaskScheduler:
     """
-    This is a simple task sheduler aimed at resumable tasks: usually we use `run_in_background`
-    to launch a background task, or Twisted `deferLater` if we want to do so later on.
-
-    The problem with that is that the tasks will just stop and never be resumed if synapse
-    is stopped for whatever reason.
-
-    How this works:
-    - A function mapped to a named action should first be registered with `register_action`.
-    This function will be called when trying to resuming tasks after a synapse shutdown,
-    so this registration should happen when synapse is initialised, NOT right before scheduling
-    a task.
-    - A task can then be launched using this named action with `schedule_task`. A `params` dict
-    can be passed, and it will be available to the registered function when launched. This task
-    can be launch either now-ish, or later on by giving a `timestamp` parameter.
-    - The function may call `update_task` at any time to update the `result` of the task,
-    and this can be used to resume the task at a specific point and/or to convey a result to
-    the code launching the task.
-    You can also specify the `result` (and/or an `error`) when returning from the function.
-    - The reconciliation loop runs every minute, so this is not a precise scheduler.
-    There is a limit of 10 concurrent tasks, so tasks may be delayed if the pool is already
-    full. In this regard, please take great care that scheduled tasks can actually finished.
-    For now there is no mechanism to stop a running task if it is stuck.
-    - Tasks will be run on the worker specified with `run_background_tasks_on` config,
-    or the main one by default.
+    This is a simple task scheduler designed for resumable tasks. Normally,
+    you'd use `run_in_background` to start a background task or Twisted's
+    `deferLater` if you want to run it later.
+
+    The issue is that these tasks stop completely and won't resume if Synapse is
+    shut down for any reason.
+
+    Here's how it works:
+
+    - Register an Action: First, you need to register a function to a named
+      action using `register_action`. This function will be called to resume tasks
+      after a Synapse shutdown. Make sure to register it when Synapse initializes,
+      not right before scheduling the task.
+
+    - Schedule a Task: You can launch a task linked to the named action
+      using `schedule_task`. You can pass a `params` dictionary, which will be
+      passed to the registered function when it's executed. Tasks can be scheduled
+      to run either immediately or later by specifying a `timestamp`.
+
+    - Update Task: The function handling the task can call `update_task` at
+      any point to update the task's `result`. This lets you resume the task from
+      a specific point or pass results back to the code that scheduled it. When
+      the function completes, you can also return a `result` or an `error`.
+
+    Things to keep in mind:
+
+    - The reconciliation loop runs every minute, so this is not a high-precision
+      scheduler.
+
+    - Only 10 tasks can run at the same time. If the pool is full, tasks may be
+      delayed. Make sure your scheduled tasks can actually finish.
+
+    - Currently, there's no way to stop a task if it gets stuck.
+
+    - Tasks will run on the worker defined by the `run_background_tasks_on`
+      setting in your configuration. If no worker is specified, they'll run on
+      the main one by default.
""" # Precision of the scheduler, evaluation of tasks to run will only happen @@ -157,7 +167,7 @@ class TaskScheduler: params: Optional[JsonMapping] = None, ) -> str: """Schedule a new potentially resumable task. A function matching the specified - `action` should have be registered with `register_action` before the task is run. + `action` should've been registered with `register_action` before the task is run. Args: action: the name of a previously registered action @@ -174,9 +184,10 @@ class TaskScheduler: The id of the scheduled task """ status = TaskStatus.SCHEDULED + start_now = False if timestamp is None or timestamp < self._clock.time_msec(): timestamp = self._clock.time_msec() - status = TaskStatus.ACTIVE + start_now = True task = ScheduledTask( random_string(16), @@ -190,9 +201,11 @@ class TaskScheduler: ) await self._store.insert_scheduled_task(task) - if status == TaskStatus.ACTIVE: + # If the task is ready to run immediately, run the scheduling algorithm now + # rather than waiting + if start_now: if self._run_background_tasks: - await self._launch_task(task) + self._launch_scheduled_tasks() else: self._hs.get_replication_command_handler().send_new_active_task(task.id) @@ -207,15 +220,15 @@ class TaskScheduler: result: Optional[JsonMapping] = None, error: Optional[str] = None, ) -> bool: - """Update some task associated values. This is exposed publicly so it can - be used inside task functions, mainly to update the result and be able to - resume a task at a specific step after a restart of synapse. + """Update some task-associated values. This is exposed publicly so it can + be used inside task functions, mainly to update the result or resume + a task at a specific step after a restart of synapse. It can also be used to stage a task, by setting the `status` to `SCHEDULED` with a new timestamp. - The `status` can only be set to `ACTIVE` or `SCHEDULED`, `COMPLETE` and `FAILED` - are terminal status and can only be set by returning it in the function. + The `status` can only be set to `ACTIVE` or `SCHEDULED`. `COMPLETE` and `FAILED` + are terminal statuses and can only be set by returning them from the function. Args: id: the id of the task to update @@ -223,6 +236,12 @@ class TaskScheduler: status: the new `TaskStatus` of the task result: the new result of the task error: the new error of the task + + Returns: + True if the update was successful, False otherwise. + + Raises: + Exception: If a status other than `ACTIVE` or `SCHEDULED` was passed. """ if status == TaskStatus.COMPLETE or status == TaskStatus.FAILED: raise Exception( @@ -260,9 +279,9 @@ class TaskScheduler: max_timestamp: Optional[int] = None, limit: Optional[int] = None, ) -> List[ScheduledTask]: - """Get a list of tasks. Returns all the tasks if no args is provided. + """Get a list of tasks. Returns all the tasks if no args are provided. - If an arg is `None` all tasks matching the other args will be selected. + If an arg is `None`, all tasks matching the other args will be selected. If an arg is an empty list, the corresponding value of the task needs to be `None` to be selected. @@ -274,8 +293,8 @@ class TaskScheduler: a timestamp inferior to the specified one limit: Only return `limit` number of rows if set. - Returns - A list of `ScheduledTask`, ordered by increasing timestamps + Returns: + A list of `ScheduledTask`, ordered by increasing timestamps. 
""" return await self._store.get_scheduled_tasks( actions=actions, @@ -300,23 +319,13 @@ class TaskScheduler: raise Exception(f"Task {id} is currently ACTIVE and can't be deleted") await self._store.delete_scheduled_task(id) - def launch_task_by_id(self, id: str) -> None: - """Try launching the task with the given ID.""" - # Don't bother trying to launch new tasks if we're already at capacity. - if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: - return - - run_as_background_process("launch_task_by_id", self._launch_task_by_id, id) + def on_new_task(self, task_id: str) -> None: + """Handle a notification that a new ready-to-run task has been added to the queue""" + # Just run the scheduler + self._launch_scheduled_tasks() - async def _launch_task_by_id(self, id: str) -> None: - """Helper async function for `launch_task_by_id`.""" - task = await self.get_task(id) - if task: - await self._launch_task(task) - - @wrap_as_background_process("launch_scheduled_tasks") - async def _launch_scheduled_tasks(self) -> None: - """Retrieve and launch scheduled tasks that should be running at that time.""" + def _launch_scheduled_tasks(self) -> None: + """Retrieve and launch scheduled tasks that should be running at this time.""" # Don't bother trying to launch new tasks if we're already at capacity. if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: return @@ -326,20 +335,26 @@ class TaskScheduler: self._launching_new_tasks = True - try: - for task in await self.get_tasks( - statuses=[TaskStatus.ACTIVE], limit=self.MAX_CONCURRENT_RUNNING_TASKS - ): - await self._launch_task(task) - for task in await self.get_tasks( - statuses=[TaskStatus.SCHEDULED], - max_timestamp=self._clock.time_msec(), - limit=self.MAX_CONCURRENT_RUNNING_TASKS, - ): - await self._launch_task(task) - - finally: - self._launching_new_tasks = False + async def inner() -> None: + try: + for task in await self.get_tasks( + statuses=[TaskStatus.ACTIVE], + limit=self.MAX_CONCURRENT_RUNNING_TASKS, + ): + # _launch_task will ignore tasks that we're already running, and + # will also do nothing if we're already at the maximum capacity. + await self._launch_task(task) + for task in await self.get_tasks( + statuses=[TaskStatus.SCHEDULED], + max_timestamp=self._clock.time_msec(), + limit=self.MAX_CONCURRENT_RUNNING_TASKS, + ): + await self._launch_task(task) + + finally: + self._launching_new_tasks = False + + run_as_background_process("launch_scheduled_tasks", inner) @wrap_as_background_process("clean_scheduled_tasks") async def _clean_scheduled_tasks(self) -> None: diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py deleted file mode 100644
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
deleted file mode 100644
index 5c9193e8a9..0000000000
--- a/synapse/util/threepids.py
+++ /dev/null
@@ -1,123 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-
-import logging
-import re
-import typing
-
-if typing.TYPE_CHECKING:
-    from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-# it's unclear what the maximum length of an email address is. RFC3696 (as corrected
-# by errata) says:
-#     the upper limit on address lengths should normally be considered to be 254.
-#
-# In practice, mail servers appear to be more tolerant and allow 400 characters
-# or so. Let's allow 500, which should be plenty for everyone.
-#
-MAX_EMAIL_ADDRESS_LENGTH = 500
-
-
-async def check_3pid_allowed(
-    hs: "HomeServer",
-    medium: str,
-    address: str,
-    registration: bool = False,
-) -> bool:
-    """Checks whether a given format of 3PID is allowed to be used on this HS
-
-    Args:
-        hs: server
-        medium: 3pid medium - e.g. email, msisdn
-        address: address within that medium (e.g. "wotan@matrix.org")
-            msisdns need to first have been canonicalised
-        registration: whether we want to bind the 3PID as part of registering a new user.
-
-    Returns:
-        whether the 3PID medium/address is allowed to be added to this HS
-    """
-    if not await hs.get_password_auth_provider().is_3pid_allowed(
-        medium, address, registration
-    ):
-        return False
-
-    if hs.config.registration.allowed_local_3pids:
-        for constraint in hs.config.registration.allowed_local_3pids:
-            logger.debug(
-                "Checking 3PID %s (%s) against %s (%s)",
-                address,
-                medium,
-                constraint["pattern"],
-                constraint["medium"],
-            )
-            if medium == constraint["medium"] and re.match(
-                constraint["pattern"], address
-            ):
-                return True
-    else:
-        return True
-
-    return False
-
-
-def canonicalise_email(address: str) -> str:
-    """'Canonicalise' email address
-    Case folding of local part of email address and lowercase domain part
-    See MSC2265, https://github.com/matrix-org/matrix-doc/pull/2265
-
-    Args:
-        address: email address to be canonicalised
-    Returns:
-        The canonical form of the email address
-    Raises:
-        ValueError if the address could not be parsed.
-    """
-
-    address = address.strip()
-
-    parts = address.split("@")
-    if len(parts) != 2:
-        logger.debug("Couldn't parse email address %s", address)
-        raise ValueError("Unable to parse email address")
-
-    return parts[0].casefold() + "@" + parts[1].lower()
-
-
-def validate_email(address: str) -> str:
-    """Does some basic validation on an email address.
-
-    Returns the canonicalised email, as returned by `canonicalise_email`.
-
-    Raises a ValueError if the email is invalid.
-    """
-    # First we try canonicalising in case that fails
-    address = canonicalise_email(address)
-
-    # Email addresses have to be at least 3 characters.
-    if len(address) < 3:
-        raise ValueError("Unable to parse email address")
-
-    if len(address) > MAX_EMAIL_ADDRESS_LENGTH:
-        raise ValueError("Unable to parse email address")
-
-    return address
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 44b109bdfd..95eb1d7185 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -47,7 +47,6 @@ class WheelTimer(Generic[T]):
         """
         self.bucket_size: int = bucket_size
         self.entries: List[_Entry[T]] = []
-        self.current_tick: int = 0
 
     def insert(self, now: int, obj: T, then: int) -> None:
         """Inserts object into timer.
@@ -78,11 +77,10 @@ class WheelTimer(Generic[T]):
             self.entries[max(min_key, then_key) - min_key].elements.add(obj)
             return
 
-        next_key = now_key + 1
         if self.entries:
-            last_key = self.entries[-1].end_key
+            last_key = self.entries[-1].end_key + 1
         else:
-            last_key = next_key
+            last_key = now_key + 1
 
         # Handle the case when `then` is in the past and `entries` is empty.
         then_key = max(last_key, then_key)
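To make the bucketing fix concrete, a usage sketch (millisecond timestamps; bucket boundaries mean timings are approximate):

    from synapse.util.wheel_timer import WheelTimer

    timer: WheelTimer[str] = WheelTimer(bucket_size=5000)

    now = 1_000_000
    timer.insert(now, "obj1", then=now + 4_000)   # due in ~4s
    timer.insert(now, "obj2", then=now + 60_000)  # due in ~60s

    timer.fetch(now)           # -> [] - nothing due yet
    timer.fetch(now + 10_000)  # -> ["obj1"]; "obj2" stays queued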