diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 70139beef2..e596e1ed20 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -41,6 +41,7 @@ from typing import (
Hashable,
Iterable,
List,
+ Literal,
Optional,
Set,
Tuple,
@@ -51,7 +52,7 @@ from typing import (
)
import attr
-from typing_extensions import Concatenate, Literal, ParamSpec
+from typing_extensions import Concatenate, ParamSpec, Unpack
from twisted.internet import defer
from twisted.internet.defer import CancelledError
@@ -61,6 +62,7 @@ from twisted.python.failure import Failure
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
+ run_coroutine_in_background,
run_in_background,
)
from synapse.util import Clock
@@ -344,6 +346,7 @@ T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
T4 = TypeVar("T4")
+T5 = TypeVar("T5")
@overload
@@ -402,6 +405,112 @@ def gather_results( # type: ignore[misc]
return deferred.addCallback(tuple)
+@overload
+async def gather_optional_coroutines(
+ *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]]]],
+) -> Tuple[Optional[T1]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+ *coroutines: Unpack[
+ Tuple[
+ Optional[Coroutine[Any, Any, T1]],
+ Optional[Coroutine[Any, Any, T2]],
+ ]
+ ],
+) -> Tuple[Optional[T1], Optional[T2]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+ *coroutines: Unpack[
+ Tuple[
+ Optional[Coroutine[Any, Any, T1]],
+ Optional[Coroutine[Any, Any, T2]],
+ Optional[Coroutine[Any, Any, T3]],
+ ]
+ ],
+) -> Tuple[Optional[T1], Optional[T2], Optional[T3]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+ *coroutines: Unpack[
+ Tuple[
+ Optional[Coroutine[Any, Any, T1]],
+ Optional[Coroutine[Any, Any, T2]],
+ Optional[Coroutine[Any, Any, T3]],
+ Optional[Coroutine[Any, Any, T4]],
+ ]
+ ],
+) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+ *coroutines: Unpack[
+ Tuple[
+ Optional[Coroutine[Any, Any, T1]],
+ Optional[Coroutine[Any, Any, T2]],
+ Optional[Coroutine[Any, Any, T3]],
+ Optional[Coroutine[Any, Any, T4]],
+ Optional[Coroutine[Any, Any, T5]],
+ ]
+ ],
+) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5]]: ...
+
+
+async def gather_optional_coroutines(
+ *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]],
+) -> Tuple[Optional[T1], ...]:
+ """Helper function that allows waiting on multiple coroutines at once.
+
+ The return value is a tuple of the return values of the coroutines in order.
+
+    If `None` is passed instead of a coroutine, it is skipped and `None` is
+    returned in its position in the tuple.
+
+    Note: For typechecking we need an explicit overload for each distinct
+    number of coroutines passed in. If you see type problems, it's likely
+    because you're passing in more coroutines than there are overloads, and a
+    new overload needs to be added above.
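+
+    Example (a minimal sketch; `fetch_foo`, `fetch_bar` and `include_bar` are
+    illustrative, not real Synapse helpers):
+
+        foo, bar = await gather_optional_coroutines(
+            fetch_foo(),
+            fetch_bar() if include_bar else None,
+        )
+        # `bar` is None if `fetch_bar` was skipped.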
+ """
+
+ try:
+ results = await make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_coroutine_in_background(coroutine)
+ for coroutine in coroutines
+ if coroutine is not None
+ ],
+ consumeErrors=True,
+ )
+ )
+
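+        # Map the results back onto the original argument positions, filling
+        # in `None` for each coroutine that was skipped.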
+ results_iter = iter(results)
+ return tuple(
+ next(results_iter) if coroutine is not None else None
+ for coroutine in coroutines
+ )
+ except defer.FirstError as dfe:
+ # unwrap the error from defer.gatherResults.
+
+ # The raised exception's traceback only includes func() etc if
+ # the 'await' happens before the exception is thrown - ie if the failure
+ # happens *asynchronously* - otherwise Twisted throws away the traceback as it
+ # could be large.
+ #
+ # We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
+ # we could throw Twisted into the fires of Mordor.
+
+ # suppress exception chaining, because the FirstError doesn't tell us anything
+ # very interesting.
+ assert isinstance(dfe.subFailure.value, BaseException)
+ raise dfe.subFailure.value from None
+
+
@attr.s(slots=True, auto_attribs=True)
class _LinearizerEntry:
# The number of things executing.
@@ -885,3 +994,46 @@ class AwakenableSleeper:
# Cancel the sleep if we were woken up
if call.active():
call.cancel()
+
+
+class DeferredEvent:
+ """Like threading.Event but for async code"""
+
+ def __init__(self, reactor: IReactorTime) -> None:
+ self._reactor = reactor
+ self._deferred: "defer.Deferred[None]" = defer.Deferred()
+
+ def set(self) -> None:
+ if not self._deferred.called:
+ self._deferred.callback(None)
+
+ def clear(self) -> None:
+ if self._deferred.called:
+ self._deferred = defer.Deferred()
+
+ def is_set(self) -> bool:
+ return self._deferred.called
+
+ async def wait(self, timeout_seconds: float) -> bool:
+ if self.is_set():
+ return True
+
+ # Create a deferred that gets called in N seconds
+ sleep_deferred: "defer.Deferred[None]" = defer.Deferred()
+ call = self._reactor.callLater(timeout_seconds, sleep_deferred.callback, None)
+
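+        # Wait for whichever fires first: the timeout or the event being set.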
+ try:
+ await make_deferred_yieldable(
+ defer.DeferredList(
+ [sleep_deferred, self._deferred],
+ fireOnOneCallback=True,
+ fireOnOneErrback=True,
+ consumeErrors=True,
+ )
+ )
+ finally:
+ # Cancel the sleep if we were woken up
+ if call.active():
+ call.cancel()
+
+ return self.is_set()
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 1e6696332f..14bd3ba3b0 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -21,10 +21,19 @@
import enum
import logging
import threading
-from typing import Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
+from typing import (
+ Dict,
+ Generic,
+ Iterable,
+ Literal,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+ Union,
+)
import attr
-from typing_extensions import Literal
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 8017c031ee..3198fdd2ed 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -21,10 +21,9 @@
import logging
from collections import OrderedDict
-from typing import Any, Generic, Iterable, Optional, TypeVar, Union, overload
+from typing import Any, Generic, Iterable, Literal, Optional, TypeVar, Union, overload
import attr
-from typing_extensions import Literal
from twisted.internet import defer
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 481a1a621e..2e5efa3a52 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -34,6 +34,7 @@ from typing import (
Generic,
Iterable,
List,
+ Literal,
Optional,
Set,
Tuple,
@@ -44,8 +45,6 @@ from typing import (
overload,
)
-from typing_extensions import Literal
-
from twisted.internet import reactor
from twisted.internet.interfaces import IReactorTime
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 96b7ca83dc..54b99134b9 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -101,7 +101,13 @@ class ResponseCache(Generic[KV]):
used rather than trying to compute a new response.
"""
- def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
+ def __init__(
+ self,
+ clock: Clock,
+ name: str,
+ timeout_ms: float = 0,
+ enable_logging: bool = True,
+ ):
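+        """
+        Args:
+            clock: the homeserver clock
+            name: a name for this cache, used in logging and metrics
+            timeout_ms: how long (in milliseconds) a completed result should
+                remain in the cache
+            enable_logging: whether cache hits and misses should be logged;
+                can be disabled for very hot caches where the logging is
+                unhelpfully noisy
+        """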
self._result_cache: Dict[KV, ResponseCacheEntry] = {}
self.clock = clock
@@ -109,6 +115,7 @@ class ResponseCache(Generic[KV]):
self._name = name
self._metrics = register_cache("response_cache", name, self, resizable=False)
+ self._enable_logging = enable_logging
def size(self) -> int:
return len(self._result_cache)
@@ -246,9 +253,12 @@ class ResponseCache(Generic[KV]):
"""
entry = self._get(key)
if not entry:
- logger.debug(
- "[%s]: no cached result for [%s], calculating new one", self._name, key
- )
+ if self._enable_logging:
+ logger.debug(
+ "[%s]: no cached result for [%s], calculating new one",
+ self._name,
+ key,
+ )
context = ResponseCacheContext(cache_key=key)
if cache_context:
kwargs["cache_context"] = context
@@ -269,12 +279,15 @@ class ResponseCache(Generic[KV]):
return await make_deferred_yieldable(entry.result.observe())
result = entry.result.observe()
- if result.called:
- logger.info("[%s]: using completed cached result for [%s]", self._name, key)
- else:
- logger.info(
- "[%s]: using incomplete cached result for [%s]", self._name, key
- )
+ if self._enable_logging:
+ if result.called:
+ logger.info(
+ "[%s]: using completed cached result for [%s]", self._name, key
+ )
+ else:
+ logger.info(
+ "[%s]: using incomplete cached result for [%s]", self._name, key
+ )
span_context = entry.opentracing_span_context
with start_active_span_follows_from(
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 16fcb00206..5ac8643eef 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -142,9 +142,9 @@ class StreamChangeCache:
"""
assert isinstance(stream_pos, int)
- # _cache is not valid at or before the earliest known stream position, so
+ # _cache is not valid before the earliest known stream position, so
# return that the entity has changed.
- if stream_pos <= self._earliest_known_stream_pos:
+ if stream_pos < self._earliest_known_stream_pos:
self.metrics.inc_misses()
return True
@@ -186,7 +186,7 @@ class StreamChangeCache:
This will be all entities if the given stream position is at or earlier
than the earliest known stream position.
"""
- if not self._cache or stream_pos <= self._earliest_known_stream_pos:
+ if not self._cache or stream_pos < self._earliest_known_stream_pos:
self.metrics.inc_misses()
return set(entities)
@@ -238,9 +238,9 @@ class StreamChangeCache:
"""
assert isinstance(stream_pos, int)
- # _cache is not valid at or before the earliest known stream position, so
+ # _cache is not valid before the earliest known stream position, so
# return that an entity has changed.
- if stream_pos <= self._earliest_known_stream_pos:
+ if stream_pos < self._earliest_known_stream_pos:
self.metrics.inc_misses()
return True
@@ -270,9 +270,9 @@ class StreamChangeCache:
"""
assert isinstance(stream_pos, int)
- # _cache is not valid at or before the earliest known stream position, so
+ # _cache is not valid before the earliest known stream position, so
# return None to mark that it is unknown if an entity has changed.
- if stream_pos <= self._earliest_known_stream_pos:
+ if stream_pos < self._earliest_known_stream_pos:
return AllEntitiesChangedResult(None)
changed_entities: List[EntityType] = []
@@ -314,6 +314,15 @@ class StreamChangeCache:
self._entity_to_key[entity] = stream_pos
self._evict()
+ def all_entities_changed(self, stream_pos: int) -> None:
+ """
+        Mark all entities as changed. This is useful when the cache is
+        invalidated and all of the entities may potentially have changed.
+ """
+ self._cache.clear()
+ self._entity_to_key.clear()
+ self._earliest_known_stream_pos = stream_pos
+
def _evict(self) -> None:
"""
Ensure the cache has not exceeded the maximum size.
diff --git a/synapse/util/events.py b/synapse/util/events.py
new file mode 100644
index 0000000000..ad9b946578
--- /dev/null
+++ b/synapse/util/events.py
@@ -0,0 +1,29 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+#
+
+from synapse.util.stringutils import random_string
+
+
+def generate_fake_event_id() -> str:
+ """
+ Generate an event ID from random ASCII characters.
+
+ This is primarily useful for generating fake event IDs in response to
+ requests from shadow-banned users.
+
+ Returns:
+ A string intended to look like an event ID, but with no actual meaning.
+ """
+ return "$" + random_string(43)
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index b73f690b88..0a6a30aab2 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -30,14 +30,13 @@ from typing import (
Iterator,
List,
Mapping,
+ Protocol,
Set,
Sized,
Tuple,
TypeVar,
)
-from typing_extensions import Protocol
-
T = TypeVar("T")
S = TypeVar("S", bound="_SelfSlice")
@@ -115,7 +114,7 @@ def sorted_topologically(
# This is implemented by Kahn's algorithm.
- degree_map = {node: 0 for node in nodes}
+ degree_map = dict.fromkeys(nodes, 0)
reverse_graph: Dict[T, Set[T]] = {}
for node, edges in graph.items():
@@ -165,7 +164,7 @@ def sorted_topologically_batched(
persisted.
"""
- degree_map = {node: 0 for node in nodes}
+ degree_map = dict.fromkeys(nodes, 0)
reverse_graph: Dict[T, Set[T]] = {}
for node, edges in graph.items():
diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py
index e9a5fff211..87f801c0cf 100644
--- a/synapse/util/linked_list.py
+++ b/synapse/util/linked_list.py
@@ -19,8 +19,7 @@
#
#
-"""A circular doubly linked list implementation.
-"""
+"""A circular doubly linked list implementation."""
import threading
from typing import Generic, Optional, Type, TypeVar
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
index 84ae226207..6fa15543ec 100644
--- a/synapse/util/macaroons.py
+++ b/synapse/util/macaroons.py
@@ -22,12 +22,11 @@
"""Utilities for manipulating macaroons"""
-from typing import Callable, Optional
+from typing import Callable, Literal, Optional
import attr
import pymacaroons
from pymacaroons.exceptions import MacaroonVerificationFailedException
-from typing_extensions import Literal
from synapse.util import Clock, stringutils
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 517e79ce5f..6a389f7a7e 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -22,10 +22,19 @@
import logging
from functools import wraps
from types import TracebackType
-from typing import Awaitable, Callable, Dict, Generator, Optional, Type, TypeVar
+from typing import (
+ Awaitable,
+ Callable,
+ Dict,
+ Generator,
+ Optional,
+ Protocol,
+ Type,
+ TypeVar,
+)
from prometheus_client import CollectorRegistry, Counter, Metric
-from typing_extensions import Concatenate, ParamSpec, Protocol
+from typing_extensions import Concatenate, ParamSpec
from synapse.logging.context import (
ContextResourceUsage,
@@ -110,7 +119,7 @@ def measure_func(
"""
def wrapper(
- func: Callable[Concatenate[HasClock, P], Awaitable[R]]
+ func: Callable[Concatenate[HasClock, P], Awaitable[R]],
) -> Callable[P, Awaitable[R]]:
block_name = func.__name__ if name is None else name
diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py
deleted file mode 100644
index b6a784f0bc..0000000000
--- a/synapse/util/msisdn.py
+++ /dev/null
@@ -1,51 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2017 Vector Creations Ltd
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-
-import phonenumbers
-
-from synapse.api.errors import SynapseError
-
-
-def phone_number_to_msisdn(country: str, number: str) -> str:
- """
- Takes an ISO-3166-1 2 letter country code and phone number and
- returns an msisdn representing the canonical version of that
- phone number.
-
- As an example, if `country` is "GB" and `number` is "7470674927", this
- function will return "447470674927".
-
- Args:
- country: ISO-3166-1 2 letter country code
- number: Phone number in a national or international format
-
- Returns:
- The canonical form of the phone number, as an msisdn.
- Raises:
- SynapseError if the number could not be parsed.
- """
- try:
- phoneNumber = phonenumbers.parse(number, country)
- except phonenumbers.NumberParseException:
- raise SynapseError(400, "Unable to parse phone number")
- return phonenumbers.format_number(phoneNumber, phonenumbers.PhoneNumberFormat.E164)[
- 1:
- ]
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 46dad32156..beea4d2888 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -50,7 +50,7 @@ def do_patch() -> None:
return
def new_inline_callbacks(
- f: Callable[P, Generator["Deferred[object]", object, T]]
+ f: Callable[P, Generator["Deferred[object]", object, T]],
) -> Callable[P, "Deferred[T]"]:
@functools.wraps(f)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> "Deferred[T]":
@@ -162,7 +162,7 @@ def _check_yield_points(
d = result.throwExceptionIntoGenerator(gen)
else:
d = gen.send(result)
- except (StopIteration, defer._DefGen_Return) as e:
+ except StopIteration as e:
if current_context() != expected_context:
# This happens when the context is lost sometime *after* the
# final yield and returning. E.g. we forgot to yield on a
@@ -183,7 +183,7 @@ def _check_yield_points(
)
)
changes.append(err)
- # The `StopIteration` or `_DefGen_Return` contains the return value from the
+ # The `StopIteration` contains the return value from the
# generator.
return cast(T, e.value)
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 8ead72bb7a..3f067b792c 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -103,7 +103,7 @@ _rate_limiter_instances_lock = threading.Lock()
def _get_counts_from_rate_limiter_instance(
- count_func: Callable[["FederationRateLimiter"], int]
+ count_func: Callable[["FederationRateLimiter"], int],
) -> Mapping[Tuple[str, ...], int]:
"""Returns a count of something (slept/rejected hosts) by (metrics_name)"""
# Cast to a list to prevent it changing while the Prometheus
diff --git a/synapse/util/rust.py b/synapse/util/rust.py
index 0e35d6d188..37f43459f1 100644
--- a/synapse/util/rust.py
+++ b/synapse/util/rust.py
@@ -19,9 +19,12 @@
#
#
+import json
import os
-import sys
+import urllib.parse
from hashlib import blake2b
+from importlib.metadata import Distribution, PackageNotFoundError
+from typing import Optional
import synapse
from synapse.synapse_rust import get_rust_file_digest
@@ -32,22 +35,17 @@ def check_rust_lib_up_to_date() -> None:
be rebuilt.
"""
- if not _dist_is_editable():
- return
-
- synapse_dir = os.path.dirname(synapse.__file__)
- synapse_root = os.path.abspath(os.path.join(synapse_dir, ".."))
-
- # Double check we've not gone into site-packages...
- if os.path.basename(synapse_root) == "site-packages":
- return
-
- # ... and it looks like the root of a python project.
- if not os.path.exists("pyproject.toml"):
- return
+ # Get the location of the editable install.
+ synapse_root = get_synapse_source_directory()
+ if synapse_root is None:
+ return None
# Get the hash of all Rust source files
- hash = _hash_rust_files_in_directory(os.path.join(synapse_root, "rust", "src"))
+ rust_path = os.path.join(synapse_root, "rust", "src")
+ if not os.path.exists(rust_path):
+ return None
+
+ hash = _hash_rust_files_in_directory(rust_path)
if hash != get_rust_file_digest():
raise Exception("Rust module outdated. Please rebuild using `poetry install`")
@@ -82,10 +80,55 @@ def _hash_rust_files_in_directory(directory: str) -> str:
return hasher.hexdigest()
-def _dist_is_editable() -> bool:
- """Is distribution an editable install?"""
- for path_item in sys.path:
- egg_link = os.path.join(path_item, "matrix-synapse.egg-link")
- if os.path.isfile(egg_link):
- return True
- return False
+def get_synapse_source_directory() -> Optional[str]:
+ """Try and find the source directory of synapse for editable installs (like
+ those used in development).
+
+ Returns None if not an editable install (or otherwise can't find the source
+ directory).
+ """
+
+ # Try and find the installed matrix-synapse package.
+ try:
+ package = Distribution.from_name("matrix-synapse")
+ except PackageNotFoundError:
+        # The package was not found, so it's not installed and must instead be
+        # being run from a local directory (usually the current one).
+ synapse_dir = os.path.dirname(synapse.__file__)
+ synapse_root = os.path.abspath(os.path.join(synapse_dir, ".."))
+
+ # Double check we've not gone into site-packages...
+ if os.path.basename(synapse_root) == "site-packages":
+ return None
+
+ # ... and it looks like the root of a python project.
+ if not os.path.exists("pyproject.toml"):
+ return None
+
+ return synapse_root
+
+ # Read the `direct_url.json` metadata for the package. This won't exist for
+ # packages installed via a repository/etc.
+ # c.f. https://packaging.python.org/en/latest/specifications/direct-url/
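+    #
+    # For an editable install, `direct_url.json` typically looks something
+    # like (an illustrative example):
+    #
+    #     {"url": "file:///path/to/synapse", "dir_info": {"editable": true}}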
+ direct_url_json = package.read_text("direct_url.json")
+ if direct_url_json is None:
+ return None
+
+ # c.f. https://packaging.python.org/en/latest/specifications/direct-url/ for
+ # the format
+ direct_url_dict: dict = json.loads(direct_url_json)
+
+ # `url` must exist as a key, and point to where we fetched the repo from.
+ project_url = urllib.parse.urlparse(direct_url_dict["url"])
+
+    # If it's not a local file then we must have built the rust libs either
+    # a) after we downloaded the package, or b) when we built the downloaded
+    # wheel.
+ if project_url.scheme != "file":
+ return None
+
+    # And finally, if it's not an editable install then the files can't have
+    # changed since we installed the package.
+ if not direct_url_dict.get("dir_info", {}).get("editable", False):
+ return None
+
+ return project_url.path
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 13ff54b669..32b5bc00c9 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -43,6 +43,14 @@ CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
#
MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
+# https://spec.matrix.org/v1.13/appendices/#common-namespaced-identifier-grammar
+#
+# At least one character and at most 255 characters. Must start with a-z; the
+# remaining characters may be a-z, 0-9, `-`, `_`, or `.`.
+#
+# This doesn't check the validity of the namespace itself.
+NAMESPACED_GRAMMAR = re.compile(r"^[a-z][a-z0-9_.-]{0,254}$")
+
def random_string(length: int) -> str:
"""Generate a cryptographically secure string of random letters.
@@ -68,6 +76,10 @@ def is_ascii(s: bytes) -> bool:
return True
+def is_namedspaced_grammar(s: str) -> bool:
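+    """Check whether `s` matches the common namespaced identifier grammar,
+    e.g. "m.room.message" or "org.example.custom" (illustrative examples)."""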
+ return bool(NAMESPACED_GRAMMAR.match(s))
+
+
def assert_valid_client_secret(client_secret: str) -> None:
"""Validate that a given string matches the client_secret defined by the spec"""
if (
diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
index 448960b297..4683d09cd7 100644
--- a/synapse/util/task_scheduler.py
+++ b/synapse/util/task_scheduler.py
@@ -46,33 +46,43 @@ logger = logging.getLogger(__name__)
class TaskScheduler:
"""
- This is a simple task sheduler aimed at resumable tasks: usually we use `run_in_background`
- to launch a background task, or Twisted `deferLater` if we want to do so later on.
-
- The problem with that is that the tasks will just stop and never be resumed if synapse
- is stopped for whatever reason.
-
- How this works:
- - A function mapped to a named action should first be registered with `register_action`.
- This function will be called when trying to resuming tasks after a synapse shutdown,
- so this registration should happen when synapse is initialised, NOT right before scheduling
- a task.
- - A task can then be launched using this named action with `schedule_task`. A `params` dict
- can be passed, and it will be available to the registered function when launched. This task
- can be launch either now-ish, or later on by giving a `timestamp` parameter.
-
- The function may call `update_task` at any time to update the `result` of the task,
- and this can be used to resume the task at a specific point and/or to convey a result to
- the code launching the task.
- You can also specify the `result` (and/or an `error`) when returning from the function.
-
- The reconciliation loop runs every minute, so this is not a precise scheduler.
- There is a limit of 10 concurrent tasks, so tasks may be delayed if the pool is already
- full. In this regard, please take great care that scheduled tasks can actually finished.
- For now there is no mechanism to stop a running task if it is stuck.
-
- Tasks will be run on the worker specified with `run_background_tasks_on` config,
- or the main one by default.
+ This is a simple task scheduler designed for resumable tasks. Normally,
+ you'd use `run_in_background` to start a background task or Twisted's
+ `deferLater` if you want to run it later.
+
+ The issue is that these tasks stop completely and won't resume if Synapse is
+ shut down for any reason.
+
+ Here's how it works:
+
+ - Register an Action: First, you need to register a function to a named
+ action using `register_action`. This function will be called to resume tasks
+ after a Synapse shutdown. Make sure to register it when Synapse initializes,
+ not right before scheduling the task.
+
+ - Schedule a Task: You can launch a task linked to the named action
+ using `schedule_task`. You can pass a `params` dictionary, which will be
+ passed to the registered function when it's executed. Tasks can be scheduled
+ to run either immediately or later by specifying a `timestamp`.
+
+ - Update Task: The function handling the task can call `update_task` at
+ any point to update the task's `result`. This lets you resume the task from
+ a specific point or pass results back to the code that scheduled it. When
+ the function completes, you can also return a `result` or an `error`.
+
+ Things to keep in mind:
+
+ - The reconciliation loop runs every minute, so this is not a high-precision
+ scheduler.
+
+ - Only 10 tasks can run at the same time. If the pool is full, tasks may be
+ delayed. Make sure your scheduled tasks can actually finish.
+
+ - Currently, there's no way to stop a task if it gets stuck.
+
+ - Tasks will run on the worker defined by the `run_background_tasks_on`
+ setting in your configuration. If no worker is specified, they'll run on
+ the main one by default.
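+
+    A minimal usage sketch (the action name and handler below are
+    illustrative, not real Synapse tasks):
+
+        async def _do_work(
+            task: ScheduledTask,
+        ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+            # ... do the work, optionally calling `update_task` along the way ...
+            return TaskStatus.COMPLETE, {"result": "done"}, None
+
+        # At initialisation time:
+        task_scheduler.register_action(_do_work, "do_work")
+
+        # Later, when there is work to schedule:
+        task_id = await task_scheduler.schedule_task(
+            "do_work", params={"some": "param"}
+        )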
"""
# Precision of the scheduler, evaluation of tasks to run will only happen
@@ -157,7 +167,7 @@ class TaskScheduler:
params: Optional[JsonMapping] = None,
) -> str:
"""Schedule a new potentially resumable task. A function matching the specified
- `action` should have be registered with `register_action` before the task is run.
+        `action` should have been registered with `register_action` before the task is run.
Args:
action: the name of a previously registered action
@@ -174,9 +184,10 @@ class TaskScheduler:
The id of the scheduled task
"""
status = TaskStatus.SCHEDULED
+ start_now = False
if timestamp is None or timestamp < self._clock.time_msec():
timestamp = self._clock.time_msec()
- status = TaskStatus.ACTIVE
+ start_now = True
task = ScheduledTask(
random_string(16),
@@ -190,9 +201,11 @@ class TaskScheduler:
)
await self._store.insert_scheduled_task(task)
- if status == TaskStatus.ACTIVE:
+ # If the task is ready to run immediately, run the scheduling algorithm now
+ # rather than waiting
+ if start_now:
if self._run_background_tasks:
- await self._launch_task(task)
+ self._launch_scheduled_tasks()
else:
self._hs.get_replication_command_handler().send_new_active_task(task.id)
@@ -207,15 +220,15 @@ class TaskScheduler:
result: Optional[JsonMapping] = None,
error: Optional[str] = None,
) -> bool:
- """Update some task associated values. This is exposed publicly so it can
- be used inside task functions, mainly to update the result and be able to
- resume a task at a specific step after a restart of synapse.
+ """Update some task-associated values. This is exposed publicly so it can
+ be used inside task functions, mainly to update the result or resume
+ a task at a specific step after a restart of synapse.
It can also be used to stage a task, by setting the `status` to `SCHEDULED` with
a new timestamp.
- The `status` can only be set to `ACTIVE` or `SCHEDULED`, `COMPLETE` and `FAILED`
- are terminal status and can only be set by returning it in the function.
+ The `status` can only be set to `ACTIVE` or `SCHEDULED`. `COMPLETE` and `FAILED`
+ are terminal statuses and can only be set by returning them from the function.
Args:
id: the id of the task to update
@@ -223,6 +236,12 @@ class TaskScheduler:
status: the new `TaskStatus` of the task
result: the new result of the task
error: the new error of the task
+
+ Returns:
+ True if the update was successful, False otherwise.
+
+ Raises:
+ Exception: If a status other than `ACTIVE` or `SCHEDULED` was passed.
"""
if status == TaskStatus.COMPLETE or status == TaskStatus.FAILED:
raise Exception(
@@ -260,9 +279,9 @@ class TaskScheduler:
max_timestamp: Optional[int] = None,
limit: Optional[int] = None,
) -> List[ScheduledTask]:
- """Get a list of tasks. Returns all the tasks if no args is provided.
+ """Get a list of tasks. Returns all the tasks if no args are provided.
- If an arg is `None` all tasks matching the other args will be selected.
+ If an arg is `None`, all tasks matching the other args will be selected.
If an arg is an empty list, the corresponding value of the task needs
to be `None` to be selected.
@@ -274,8 +293,8 @@ class TaskScheduler:
a timestamp inferior to the specified one
limit: Only return `limit` number of rows if set.
- Returns
- A list of `ScheduledTask`, ordered by increasing timestamps
+ Returns:
+ A list of `ScheduledTask`, ordered by increasing timestamps.
"""
return await self._store.get_scheduled_tasks(
actions=actions,
@@ -300,23 +319,13 @@ class TaskScheduler:
raise Exception(f"Task {id} is currently ACTIVE and can't be deleted")
await self._store.delete_scheduled_task(id)
- def launch_task_by_id(self, id: str) -> None:
- """Try launching the task with the given ID."""
- # Don't bother trying to launch new tasks if we're already at capacity.
- if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
- return
-
- run_as_background_process("launch_task_by_id", self._launch_task_by_id, id)
+ def on_new_task(self, task_id: str) -> None:
+ """Handle a notification that a new ready-to-run task has been added to the queue"""
+ # Just run the scheduler
+ self._launch_scheduled_tasks()
- async def _launch_task_by_id(self, id: str) -> None:
- """Helper async function for `launch_task_by_id`."""
- task = await self.get_task(id)
- if task:
- await self._launch_task(task)
-
- @wrap_as_background_process("launch_scheduled_tasks")
- async def _launch_scheduled_tasks(self) -> None:
- """Retrieve and launch scheduled tasks that should be running at that time."""
+ def _launch_scheduled_tasks(self) -> None:
+ """Retrieve and launch scheduled tasks that should be running at this time."""
# Don't bother trying to launch new tasks if we're already at capacity.
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
return
@@ -326,20 +335,26 @@ class TaskScheduler:
self._launching_new_tasks = True
- try:
- for task in await self.get_tasks(
- statuses=[TaskStatus.ACTIVE], limit=self.MAX_CONCURRENT_RUNNING_TASKS
- ):
- await self._launch_task(task)
- for task in await self.get_tasks(
- statuses=[TaskStatus.SCHEDULED],
- max_timestamp=self._clock.time_msec(),
- limit=self.MAX_CONCURRENT_RUNNING_TASKS,
- ):
- await self._launch_task(task)
-
- finally:
- self._launching_new_tasks = False
+ async def inner() -> None:
+ try:
+ for task in await self.get_tasks(
+ statuses=[TaskStatus.ACTIVE],
+ limit=self.MAX_CONCURRENT_RUNNING_TASKS,
+ ):
+ # _launch_task will ignore tasks that we're already running, and
+ # will also do nothing if we're already at the maximum capacity.
+ await self._launch_task(task)
+ for task in await self.get_tasks(
+ statuses=[TaskStatus.SCHEDULED],
+ max_timestamp=self._clock.time_msec(),
+ limit=self.MAX_CONCURRENT_RUNNING_TASKS,
+ ):
+ await self._launch_task(task)
+
+ finally:
+ self._launching_new_tasks = False
+
+ run_as_background_process("launch_scheduled_tasks", inner)
@wrap_as_background_process("clean_scheduled_tasks")
async def _clean_scheduled_tasks(self) -> None:
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
deleted file mode 100644
index 5c9193e8a9..0000000000
--- a/synapse/util/threepids.py
+++ /dev/null
@@ -1,123 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-
-import logging
-import re
-import typing
-
-if typing.TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-# it's unclear what the maximum length of an email address is. RFC3696 (as corrected
-# by errata) says:
-# the upper limit on address lengths should normally be considered to be 254.
-#
-# In practice, mail servers appear to be more tolerant and allow 400 characters
-# or so. Let's allow 500, which should be plenty for everyone.
-#
-MAX_EMAIL_ADDRESS_LENGTH = 500
-
-
-async def check_3pid_allowed(
- hs: "HomeServer",
- medium: str,
- address: str,
- registration: bool = False,
-) -> bool:
- """Checks whether a given format of 3PID is allowed to be used on this HS
-
- Args:
- hs: server
- medium: 3pid medium - e.g. email, msisdn
- address: address within that medium (e.g. "wotan@matrix.org")
- msisdns need to first have been canonicalised
- registration: whether we want to bind the 3PID as part of registering a new user.
-
- Returns:
- whether the 3PID medium/address is allowed to be added to this HS
- """
- if not await hs.get_password_auth_provider().is_3pid_allowed(
- medium, address, registration
- ):
- return False
-
- if hs.config.registration.allowed_local_3pids:
- for constraint in hs.config.registration.allowed_local_3pids:
- logger.debug(
- "Checking 3PID %s (%s) against %s (%s)",
- address,
- medium,
- constraint["pattern"],
- constraint["medium"],
- )
- if medium == constraint["medium"] and re.match(
- constraint["pattern"], address
- ):
- return True
- else:
- return True
-
- return False
-
-
-def canonicalise_email(address: str) -> str:
- """'Canonicalise' email address
- Case folding of local part of email address and lowercase domain part
- See MSC2265, https://github.com/matrix-org/matrix-doc/pull/2265
-
- Args:
- address: email address to be canonicalised
- Returns:
- The canonical form of the email address
- Raises:
- ValueError if the address could not be parsed.
- """
-
- address = address.strip()
-
- parts = address.split("@")
- if len(parts) != 2:
- logger.debug("Couldn't parse email address %s", address)
- raise ValueError("Unable to parse email address")
-
- return parts[0].casefold() + "@" + parts[1].lower()
-
-
-def validate_email(address: str) -> str:
- """Does some basic validation on an email address.
-
- Returns the canonicalised email, as returned by `canonicalise_email`.
-
- Raises a ValueError if the email is invalid.
- """
- # First we try canonicalising in case that fails
- address = canonicalise_email(address)
-
- # Email addresses have to be at least 3 characters.
- if len(address) < 3:
- raise ValueError("Unable to parse email address")
-
- if len(address) > MAX_EMAIL_ADDRESS_LENGTH:
- raise ValueError("Unable to parse email address")
-
- return address
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 44b109bdfd..95eb1d7185 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -47,7 +47,6 @@ class WheelTimer(Generic[T]):
"""
self.bucket_size: int = bucket_size
self.entries: List[_Entry[T]] = []
- self.current_tick: int = 0
def insert(self, now: int, obj: T, then: int) -> None:
"""Inserts object into timer.
@@ -78,11 +77,10 @@ class WheelTimer(Generic[T]):
self.entries[max(min_key, then_key) - min_key].elements.add(obj)
return
- next_key = now_key + 1
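+        # Work out the first bucket key we may need to create: one past the
+        # end of the existing entries, or the tick after `now` if there are
+        # none.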
if self.entries:
- last_key = self.entries[-1].end_key
+ last_key = self.entries[-1].end_key + 1
else:
- last_key = next_key
+ last_key = now_key + 1
# Handle the case when `then` is in the past and `entries` is empty.
then_key = max(last_key, then_key)
|