From cea1b58c4a5850a59e14808324ce1d9f222f52e5 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 3 Mar 2022 12:47:55 +0000 Subject: Don't impose version checks on dev extras at runtime (#12129) * Fix incorrect argument in test case * Add copyright header * Docstring and __all__ * Exclude dev depenencies * Use changelog from #12088 * Include version in error messages This will hopefully distinguish between the version of the source code and the version of the distribution package that is installed. * Linter script is your friend --- synapse/util/check_dependencies.py | 61 +++++++++++++++++++++++++++++++++----- 1 file changed, 54 insertions(+), 7 deletions(-) (limited to 'synapse/util') diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py index 3a1f6b3c75..39b0a91db3 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py @@ -1,3 +1,25 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This module exposes a single function which checks synapse's dependencies are present +and correctly versioned. It makes use of `importlib.metadata` to do so. The details +are a bit murky: there's no easy way to get a map from "extras" to the packages they +require. But this is probably just symptomatic of Python's package management. +""" + import logging from typing import Iterable, NamedTuple, Optional @@ -10,6 +32,8 @@ try: except ImportError: import importlib_metadata as metadata # type: ignore[no-redef] +__all__ = ["check_requirements"] + class DependencyException(Exception): @property @@ -29,7 +53,17 @@ class DependencyException(Exception): yield '"' + i + '"' -EXTRAS = set(metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra")) +DEV_EXTRAS = {"lint", "mypy", "test", "dev"} +RUNTIME_EXTRAS = ( + set(metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra")) - DEV_EXTRAS +) +VERSION = metadata.version(DISTRIBUTION_NAME) + + +def _is_dev_dependency(req: Requirement) -> bool: + return req.marker is not None and any( + req.marker.evaluate({"extra": e}) for e in DEV_EXTRAS + ) class Dependency(NamedTuple): @@ -43,6 +77,9 @@ def _generic_dependencies() -> Iterable[Dependency]: assert requirements is not None for raw_requirement in requirements: req = Requirement(raw_requirement) + if _is_dev_dependency(req): + continue + # https://packaging.pypa.io/en/latest/markers.html#usage notes that # > Evaluating an extra marker with no environment is an error # so we pass in a dummy empty extra value here. @@ -56,6 +93,8 @@ def _dependencies_for_extra(extra: str) -> Iterable[Dependency]: assert requirements is not None for raw_requirement in requirements: req = Requirement(raw_requirement) + if _is_dev_dependency(req): + continue # Exclude mandatory deps by only selecting deps needed with this extra. if ( req.marker is not None @@ -67,18 +106,26 @@ def _dependencies_for_extra(extra: str) -> Iterable[Dependency]: def _not_installed(requirement: Requirement, extra: Optional[str] = None) -> str: if extra: - return f"Need {requirement.name} for {extra}, but it is not installed" + return ( + f"Synapse {VERSION} needs {requirement.name} for {extra}, " + f"but it is not installed" + ) else: - return f"Need {requirement.name}, but it is not installed" + return f"Synapse {VERSION} needs {requirement.name}, but it is not installed" def _incorrect_version( requirement: Requirement, got: str, extra: Optional[str] = None ) -> str: if extra: - return f"Need {requirement} for {extra}, but got {requirement.name}=={got}" + return ( + f"Synapse {VERSION} needs {requirement} for {extra}, " + f"but got {requirement.name}=={got}" + ) else: - return f"Need {requirement}, but got {requirement.name}=={got}" + return ( + f"Synapse {VERSION} needs {requirement}, but got {requirement.name}=={got}" + ) def check_requirements(extra: Optional[str] = None) -> None: @@ -100,10 +147,10 @@ def check_requirements(extra: Optional[str] = None) -> None: # First work out which dependencies are required, and which are optional. if extra is None: dependencies = _generic_dependencies() - elif extra in EXTRAS: + elif extra in RUNTIME_EXTRAS: dependencies = _dependencies_for_extra(extra) else: - raise ValueError(f"Synapse does not provide the feature '{extra}'") + raise ValueError(f"Synapse {VERSION} does not provide the feature '{extra}'") deps_unfulfilled = [] errors = [] -- cgit 1.5.1 From 75574726a766f09d955c05672d400c65cb341810 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 4 Mar 2022 15:37:02 +0000 Subject: Add type hints for `ObservableDeferred` attributes (#12159) Signed-off-by: Sean Quah --- changelog.d/12159.misc | 1 + synapse/util/async_helpers.py | 14 +++++++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) create mode 100644 changelog.d/12159.misc (limited to 'synapse/util') diff --git a/changelog.d/12159.misc b/changelog.d/12159.misc new file mode 100644 index 0000000000..30500f2fd9 --- /dev/null +++ b/changelog.d/12159.misc @@ -0,0 +1 @@ +Add type hints for `ObservableDeferred` attributes. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 60c03a66fd..a9f67dcbac 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -40,7 +40,7 @@ from typing import ( ) import attr -from typing_extensions import ContextManager +from typing_extensions import ContextManager, Literal from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -96,6 +96,10 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): __slots__ = ["_deferred", "_observers", "_result"] + _deferred: "defer.Deferred[_T]" + _observers: Union[List["defer.Deferred[_T]"], Tuple[()]] + _result: Union[None, Tuple[Literal[True], _T], Tuple[Literal[False], Failure]] + def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False): object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_result", None) @@ -158,12 +162,14 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): effect the underlying deferred. """ if not self._result: + assert isinstance(self._observers, list) d: "defer.Deferred[_T]" = defer.Deferred() self._observers.append(d) return d + elif self._result[0]: + return defer.succeed(self._result[1]) else: - success, res = self._result - return defer.succeed(res) if success else defer.fail(res) + return defer.fail(self._result[1]) def observers(self) -> "Collection[defer.Deferred[_T]]": return self._observers @@ -175,6 +181,8 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): return self._result is not None and self._result[0] is True def get_result(self) -> Union[_T, Failure]: + if self._result is None: + raise ValueError(f"{self!r} has no result yet") return self._result[1] def __getattr__(self, name: str) -> Any: -- cgit 1.5.1 From 2eef234ae367657d4fe5cb0bef6bda67e97b7e4d Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 8 Mar 2022 10:47:28 +0000 Subject: Fix a bug introduced in 1.54.0rc1 which meant that Synapse would refuse to start if pre-release versions of dependencies were installed. (#12177) * Add failing test to characterise the regression #12176 * Permit pre-release versions of specified packages * Newsfile (bugfix) Signed-off-by: Olivier Wilkinson (reivilibre) --- changelog.d/12177.bugfix | 1 + synapse/util/check_dependencies.py | 3 ++- tests/util/test_check_dependencies.py | 19 +++++++++++++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12177.bugfix (limited to 'synapse/util') diff --git a/changelog.d/12177.bugfix b/changelog.d/12177.bugfix new file mode 100644 index 0000000000..3f2852f345 --- /dev/null +++ b/changelog.d/12177.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.54.0rc1 which meant that Synapse would refuse to start if pre-release versions of dependencies were installed. \ No newline at end of file diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py index 39b0a91db3..12cd804939 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py @@ -163,7 +163,8 @@ def check_requirements(extra: Optional[str] = None) -> None: deps_unfulfilled.append(requirement.name) errors.append(_not_installed(requirement, extra)) else: - if not requirement.specifier.contains(dist.version): + # We specify prereleases=True to allow prereleases such as RCs. + if not requirement.specifier.contains(dist.version, prereleases=True): deps_unfulfilled.append(requirement.name) errors.append(_incorrect_version(requirement, dist.version, extra)) diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py index a91c33272f..38e9f58ac6 100644 --- a/tests/util/test_check_dependencies.py +++ b/tests/util/test_check_dependencies.py @@ -27,7 +27,9 @@ class DummyDistribution(metadata.Distribution): old = DummyDistribution("0.1.2") +old_release_candidate = DummyDistribution("0.1.2rc3") new = DummyDistribution("1.2.3") +new_release_candidate = DummyDistribution("1.2.3rc4") # could probably use stdlib TestCase --- no need for twisted here @@ -110,3 +112,20 @@ class TestDependencyChecker(TestCase): with self.mock_installed_package(new): # should not raise check_requirements("cool-extra") + + def test_release_candidates_satisfy_dependency(self) -> None: + """ + Tests that release candidates count as far as satisfying a dependency + is concerned. + (Regression test, see #12176.) + """ + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1"], + ): + with self.mock_installed_package(old_release_candidate): + self.assertRaises(DependencyException, check_requirements) + + with self.mock_installed_package(new_release_candidate): + # should not raise + check_requirements() -- cgit 1.5.1 From 690cb4f3b32938f5ced5590abe9429733040a129 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Mar 2022 13:07:41 -0500 Subject: Allow for ignoring some arguments when caching. (#12189) * `@cached` can now take an `uncached_args` which is an iterable of names to not use in the cache key. * Requires `@cached`, @cachedList` and `@lru_cache` to use keyword arguments for clarity. * Asserts that keyword-only arguments in cached functions are not accepted. (I tested this briefly and I don't believe this works properly.) --- changelog.d/12189.misc | 1 + synapse/storage/databases/main/events_worker.py | 4 +- synapse/util/caches/descriptors.py | 74 +++++++++++++++++----- tests/util/caches/test_descriptors.py | 84 ++++++++++++++++++++++++- 4 files changed, 142 insertions(+), 21 deletions(-) create mode 100644 changelog.d/12189.misc (limited to 'synapse/util') diff --git a/changelog.d/12189.misc b/changelog.d/12189.misc new file mode 100644 index 0000000000..015e808e63 --- /dev/null +++ b/changelog.d/12189.misc @@ -0,0 +1 @@ +Support skipping some arguments when generating cache keys. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 26784f755e..59454a47df 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1286,7 +1286,7 @@ class EventsWorkerStore(SQLBaseStore): ) return {eid for ((_rid, eid), have_event) in res.items() if have_event} - @cachedList("have_seen_event", "keys") + @cachedList(cached_method_name="have_seen_event", list_name="keys") async def _have_seen_events_dict( self, keys: Iterable[Tuple[str, str]] ) -> Dict[Tuple[str, str], bool]: @@ -1954,7 +1954,7 @@ class EventsWorkerStore(SQLBaseStore): get_event_id_for_timestamp_txn, ) - @cachedList("is_partial_state_event", list_name="event_ids") + @cachedList(cached_method_name="is_partial_state_event", list_name="event_ids") async def get_partial_state_events( self, event_ids: Collection[str] ) -> Dict[str, bool]: diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 1cdead02f1..c3c5c16db9 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -20,6 +20,7 @@ from typing import ( Any, Awaitable, Callable, + Collection, Dict, Generic, Hashable, @@ -69,6 +70,7 @@ class _CacheDescriptorBase: self, orig: Callable[..., Any], num_args: Optional[int], + uncached_args: Optional[Collection[str]] = None, cache_context: bool = False, ): self.orig = orig @@ -76,6 +78,13 @@ class _CacheDescriptorBase: arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args + # There's no reason that keyword-only arguments couldn't be supported, + # but right now they're buggy so do not allow them. + if arg_spec.kwonlyargs: + raise ValueError( + "_CacheDescriptorBase does not support keyword-only arguments." + ) + if "cache_context" in all_args: if not cache_context: raise ValueError( @@ -88,6 +97,9 @@ class _CacheDescriptorBase: " named `cache_context`" ) + if num_args is not None and uncached_args is not None: + raise ValueError("Cannot provide both num_args and uncached_args") + if num_args is None: num_args = len(all_args) - 1 if cache_context: @@ -105,6 +117,12 @@ class _CacheDescriptorBase: # list of the names of the args used as the cache key self.arg_names = all_args[1 : num_args + 1] + # If there are args to not cache on, filter them out (and fix the size of num_args). + if uncached_args is not None: + include_arg_in_cache_key = [n not in uncached_args for n in self.arg_names] + else: + include_arg_in_cache_key = [True] * len(self.arg_names) + # self.arg_defaults is a map of arg name to its default value for each # argument that has a default value if arg_spec.defaults: @@ -119,8 +137,8 @@ class _CacheDescriptorBase: self.add_cache_context = cache_context - self.cache_key_builder = get_cache_key_builder( - self.arg_names, self.arg_defaults + self.cache_key_builder = _get_cache_key_builder( + self.arg_names, include_arg_in_cache_key, self.arg_defaults ) @@ -130,8 +148,7 @@ class _LruCachedFunction(Generic[F]): def lru_cache( - max_entries: int = 1000, - cache_context: bool = False, + *, max_entries: int = 1000, cache_context: bool = False ) -> Callable[[F], _LruCachedFunction[F]]: """A method decorator that applies a memoizing cache around the function. @@ -186,7 +203,9 @@ class LruCacheDescriptor(_CacheDescriptorBase): max_entries: int = 1000, cache_context: bool = False, ): - super().__init__(orig, num_args=None, cache_context=cache_context) + super().__init__( + orig, num_args=None, uncached_args=None, cache_context=cache_context + ) self.max_entries = max_entries def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: @@ -260,6 +279,9 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): num_args: number of positional arguments (excluding ``self`` and ``cache_context``) to use as cache keys. Defaults to all named args of the function. + uncached_args: a list of argument names to not use as the cache key. + (``self`` and ``cache_context`` are always ignored.) Cannot be used + with num_args. tree: cache_context: iterable: @@ -273,12 +295,18 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): orig: Callable[..., Any], max_entries: int = 1000, num_args: Optional[int] = None, + uncached_args: Optional[Collection[str]] = None, tree: bool = False, cache_context: bool = False, iterable: bool = False, prune_unread_entries: bool = True, ): - super().__init__(orig, num_args=num_args, cache_context=cache_context) + super().__init__( + orig, + num_args=num_args, + uncached_args=uncached_args, + cache_context=cache_context, + ) if tree and self.num_args < 2: raise RuntimeError( @@ -369,7 +397,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): but including list_name) to use as cache keys. Defaults to all named args of the function. """ - super().__init__(orig, num_args=num_args) + super().__init__(orig, num_args=num_args, uncached_args=None) self.list_name = list_name @@ -530,8 +558,10 @@ class _CacheContext: def cached( + *, max_entries: int = 1000, num_args: Optional[int] = None, + uncached_args: Optional[Collection[str]] = None, tree: bool = False, cache_context: bool = False, iterable: bool = False, @@ -541,6 +571,7 @@ def cached( orig, max_entries=max_entries, num_args=num_args, + uncached_args=uncached_args, tree=tree, cache_context=cache_context, iterable=iterable, @@ -551,7 +582,7 @@ def cached( def cachedList( - cached_method_name: str, list_name: str, num_args: Optional[int] = None + *, cached_method_name: str, list_name: str, num_args: Optional[int] = None ) -> Callable[[F], _CachedFunction[F]]: """Creates a descriptor that wraps a function in a `CacheListDescriptor`. @@ -590,13 +621,16 @@ def cachedList( return cast(Callable[[F], _CachedFunction[F]], func) -def get_cache_key_builder( - param_names: Sequence[str], param_defaults: Mapping[str, Any] +def _get_cache_key_builder( + param_names: Sequence[str], + include_params: Sequence[bool], + param_defaults: Mapping[str, Any], ) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]: """Construct a function which will build cache keys suitable for a cached function Args: param_names: list of formal parameter names for the cached function + include_params: list of bools of whether to include the parameter name in the cache key param_defaults: a mapping from parameter name to default value for that param Returns: @@ -608,6 +642,7 @@ def get_cache_key_builder( if len(param_names) == 1: nm = param_names[0] + assert include_params[0] is True def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: if nm in kwargs: @@ -620,13 +655,18 @@ def get_cache_key_builder( else: def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: - return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs)) + return tuple( + _get_cache_key_gen( + param_names, include_params, param_defaults, args, kwargs + ) + ) return get_cache_key def _get_cache_key_gen( param_names: Iterable[str], + include_params: Iterable[bool], param_defaults: Mapping[str, Any], args: Sequence[Any], kwargs: Mapping[str, Any], @@ -637,16 +677,18 @@ def _get_cache_key_gen( This is essentially the same operation as `inspect.getcallargs`, but optimised so that we don't need to inspect the target function for each call. """ - # We loop through each arg name, looking up if its in the `kwargs`, # otherwise using the next argument in `args`. If there are no more # args then we try looking the arg name up in the defaults. pos = 0 - for nm in param_names: + for nm, inc in zip(param_names, include_params): if nm in kwargs: - yield kwargs[nm] + if inc: + yield kwargs[nm] elif pos < len(args): - yield args[pos] + if inc: + yield args[pos] pos += 1 else: - yield param_defaults[nm] + if inc: + yield param_defaults[nm] diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 19741ffcda..6a4b17527a 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -141,6 +141,84 @@ class DescriptorTestCase(unittest.TestCase): self.assertEqual(r, "chips") obj.mock.assert_not_called() + @defer.inlineCallbacks + def test_cache_uncached_args(self): + """ + Only the arguments not named in uncached_args should matter to the cache + + Note that this is identical to test_cache_num_args, but provides the + arguments differently. + """ + + class Cls: + # Note that it is important that this is not the last argument to + # test behaviour of skipping arguments properly. + @descriptors.cached(uncached_args=("arg2",)) + def fn(self, arg1, arg2, arg3): + return self.mock(arg1, arg2, arg3) + + def __init__(self): + self.mock = mock.Mock() + + obj = Cls() + obj.mock.return_value = "fish" + r = yield obj.fn(1, 2, 3) + self.assertEqual(r, "fish") + obj.mock.assert_called_once_with(1, 2, 3) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = "chips" + r = yield obj.fn(2, 3, 4) + self.assertEqual(r, "chips") + obj.mock.assert_called_once_with(2, 3, 4) + obj.mock.reset_mock() + + # the two values should now be cached; we should be able to vary + # the second argument and still get the cached result. + r = yield obj.fn(1, 4, 3) + self.assertEqual(r, "fish") + r = yield obj.fn(2, 5, 4) + self.assertEqual(r, "chips") + obj.mock.assert_not_called() + + @defer.inlineCallbacks + def test_cache_kwargs(self): + """Test that keyword arguments are treated properly""" + + class Cls: + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, kwarg1=2): + return self.mock(arg1, kwarg1=kwarg1) + + obj = Cls() + obj.mock.return_value = "fish" + r = yield obj.fn(1, kwarg1=2) + self.assertEqual(r, "fish") + obj.mock.assert_called_once_with(1, kwarg1=2) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = "chips" + r = yield obj.fn(1, kwarg1=3) + self.assertEqual(r, "chips") + obj.mock.assert_called_once_with(1, kwarg1=3) + obj.mock.reset_mock() + + # the values should now be cached. + r = yield obj.fn(1, kwarg1=2) + self.assertEqual(r, "fish") + # We should be able to not provide kwarg1 and get the cached value back. + r = yield obj.fn(1) + self.assertEqual(r, "fish") + # Keyword arguments can be in any order. + r = yield obj.fn(kwarg1=2, arg1=1) + self.assertEqual(r, "fish") + obj.mock.assert_not_called() + def test_cache_with_sync_exception(self): """If the wrapped function throws synchronously, things should continue to work""" @@ -656,7 +734,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") async def list_fn(self, args1, arg2): assert current_context().name == "c1" # we want this to behave like an asynchronous function @@ -715,7 +793,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") def list_fn(self, args1) -> "Deferred[dict]": return self.mock(args1) @@ -758,7 +836,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") async def list_fn(self, args1, arg2): # we want this to behave like an asynchronous function await run_on_reactor() -- cgit 1.5.1 From 72e7f1c420b879a0a1ef1430771698b868693ab0 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Thu, 10 Mar 2022 15:53:23 +0000 Subject: Remove workaround introduced in Synapse v1.50.0rc1 for Mjolnir compatibility. Breaks compatibility with Mjolnir v1.3.1 and earlier. (#11700) --- changelog.d/11700.removal | 1 + docs/upgrade.md | 8 ++++++++ synapse/util/__init__.py | 7 ------- 3 files changed, 9 insertions(+), 7 deletions(-) create mode 100644 changelog.d/11700.removal (limited to 'synapse/util') diff --git a/changelog.d/11700.removal b/changelog.d/11700.removal new file mode 100644 index 0000000000..d3d3c48f0f --- /dev/null +++ b/changelog.d/11700.removal @@ -0,0 +1 @@ +Remove workaround introduced in Synapse 1.50.0 for Mjolnir compatibility. Breaks compatibility with Mjolnir 1.3.1 and earlier. diff --git a/docs/upgrade.md b/docs/upgrade.md index 0d0bb066ee..95005962dc 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -106,6 +106,14 @@ You will need to ensure `synctl` is on your `PATH`. automatically, though you might need to activate a virtual environment depending on how you installed Synapse. + +## Compatibility dropped for Mjolnir 1.3.1 and earlier + +Synapse v1.55.0 drops support for Mjolnir 1.3.1 and earlier. +If you use the Mjolnir module to moderate your homeserver, +please upgrade Mjolnir to version 1.3.2 or later before upgrading Synapse. + + # Upgrading to v1.54.0 ## Legacy structured logging configuration removal diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 58b4220ff3..d8046b7553 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -31,13 +31,6 @@ from synapse.logging import context if typing.TYPE_CHECKING: pass -# FIXME Mjolnir imports glob_to_regex from this file, but it was moved to -# matrix_common. -# As a temporary workaround, we import glob_to_regex here for -# compatibility with current versions of Mjolnir. -# See https://github.com/matrix-org/mjolnir/pull/174 -from matrix_common.regex import glob_to_regex # noqa - logger = logging.getLogger(__name__) -- cgit 1.5.1 From bc9dff1d9597251a15a15475cb8e8194b2d14910 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 11 Mar 2022 07:06:21 -0500 Subject: Remove unnecessary pass statements. (#12206) --- changelog.d/12206.misc | 1 + synapse/handlers/device.py | 2 -- synapse/handlers/presence.py | 2 -- synapse/http/matrixfederationclient.py | 2 -- synapse/http/server.py | 1 - synapse/rest/media/v1/_base.py | 1 - synapse/server.py | 1 - synapse/storage/databases/main/registration.py | 2 -- synapse/storage/schema/main/delta/30/as_users.py | 1 - synapse/util/caches/treecache.py | 2 -- tests/handlers/test_password_providers.py | 1 - 11 files changed, 1 insertion(+), 15 deletions(-) create mode 100644 changelog.d/12206.misc (limited to 'synapse/util') diff --git a/changelog.d/12206.misc b/changelog.d/12206.misc new file mode 100644 index 0000000000..df59bb56cd --- /dev/null +++ b/changelog.d/12206.misc @@ -0,0 +1 @@ +Remove unnecessary `pass` statements. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d90cb259a6..d5ccaa0c37 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -371,7 +371,6 @@ class DeviceHandler(DeviceWorkerHandler): log_kv( {"reason": "User doesn't have device id.", "device_id": device_id} ) - pass else: raise @@ -414,7 +413,6 @@ class DeviceHandler(DeviceWorkerHandler): # no match set_tag("error", True) set_tag("reason", "User doesn't have that device id.") - pass else: raise diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 9927a30e6e..34d9411bbf 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -267,7 +267,6 @@ class BasePresenceHandler(abc.ABC): is_syncing: Whether or not the user is now syncing sync_time_msec: Time in ms when the user was last syncing """ - pass async def update_external_syncs_clear(self, process_id: str) -> None: """Marks all users that had been marked as syncing by a given process @@ -277,7 +276,6 @@ class BasePresenceHandler(abc.ABC): This is a no-op when presence is handled by a different worker. """ - pass async def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: list diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 40bf1e06d6..6b98d865f5 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -120,7 +120,6 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC): """Called when response has finished streaming and the parser should return the final result (or error). """ - pass @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -601,7 +600,6 @@ class MatrixFederationHttpClient: response.code, response_phrase, ) - pass else: logger.info( "{%s} [%s] Got response headers: %d %s", diff --git a/synapse/http/server.py b/synapse/http/server.py index 09b4125489..31ca841889 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -233,7 +233,6 @@ class HttpServer(Protocol): servlet_classname (str): The name of the handler to be used in prometheus and opentracing logs. """ - pass class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 9b40fd8a6c..c35d42fab8 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -298,7 +298,6 @@ class Responder: Returns: Resolves once the response has finished being written """ - pass def __enter__(self) -> None: pass diff --git a/synapse/server.py b/synapse/server.py index 7741ff29dc..2fcf18a7a6 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -328,7 +328,6 @@ class HomeServer(metaclass=abc.ABCMeta): Does nothing in this base class; overridden in derived classes to start the appropriate listeners. """ - pass def setup_background_tasks(self) -> None: """ diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index dc6665237a..a698d10cc5 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -48,8 +48,6 @@ class ExternalIDReuseException(Exception): """Exception if writing an external id for a user fails, because this external id is given to an other user.""" - pass - @attr.s(frozen=True, slots=True, auto_attribs=True) class TokenLookupResult: diff --git a/synapse/storage/schema/main/delta/30/as_users.py b/synapse/storage/schema/main/delta/30/as_users.py index 22a7901e15..4b4b166e37 100644 --- a/synapse/storage/schema/main/delta/30/as_users.py +++ b/synapse/storage/schema/main/delta/30/as_users.py @@ -36,7 +36,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): config_files = config.appservice.app_service_config_files except AttributeError: logger.warning("Could not get app_service_config_files from config") - pass appservices = load_appservices(config.server.server_name, config_files) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 563845f867..e78305f787 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -22,8 +22,6 @@ class TreeCacheNode(dict): leaves. """ - pass - class TreeCache: """ diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 49d832de81..d401fda938 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -124,7 +124,6 @@ class PasswordCustomAuthProvider: ("m.login.password", ("password",)): self.check_auth, } ) - pass def check_auth(self, *args): return mock_password_provider.check_auth(*args) -- cgit 1.5.1 From e6a106fd5ebbf30a7c84f8ba09dc903d20213be3 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 11 Mar 2022 16:15:11 +0100 Subject: Implement a Jinja2 filter to extract localparts from email addresses (#12212) --- changelog.d/12212.feature | 1 + docs/sample_config.yaml | 3 ++- docs/templates.md | 7 +++++++ synapse/config/oidc.py | 3 ++- synapse/handlers/oidc.py | 6 ++++++ synapse/util/templates.py | 5 +++++ 6 files changed, 23 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12212.feature (limited to 'synapse/util') diff --git a/changelog.d/12212.feature b/changelog.d/12212.feature new file mode 100644 index 0000000000..fe337ff990 --- /dev/null +++ b/changelog.d/12212.feature @@ -0,0 +1 @@ +Add a new Jinja2 template filter to extract the local part of an email address. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index ef25a3175f..d634fd8ff5 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1948,7 +1948,8 @@ saml2_config: # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their # own username (see the documentation for the -# 'sso_auth_account_details.html' template). +# 'sso_auth_account_details.html' template). This template can +# use the 'localpart_from_email' filter. # # confirm_localpart: Whether to prompt the user to validate (or # change) the generated localpart (see the documentation for the diff --git a/docs/templates.md b/docs/templates.md index b251d05cb9..f87692a453 100644 --- a/docs/templates.md +++ b/docs/templates.md @@ -36,6 +36,13 @@ Turns a `mxc://` URL for media content into an HTTP(S) one using the homeserver' Example: `message.sender_avatar_url|mxc_to_http(32,32)` +```python +localpart_from_email(address: str) -> str +``` + +Returns the local part of an email address (e.g. `alice` in `alice@example.com`). + +Example: `user.email_address|localpart_from_email` ## Email templates diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index fc95912d9b..5d571651cb 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -183,7 +183,8 @@ class OIDCConfig(Config): # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their # own username (see the documentation for the - # 'sso_auth_account_details.html' template). + # 'sso_auth_account_details.html' template). This template can + # use the 'localpart_from_email' filter. # # confirm_localpart: Whether to prompt the user to validate (or # change) the generated localpart (see the documentation for the diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index d98659edc7..724b9cfcb4 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -45,6 +45,7 @@ from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart from synapse.util import Clock, json_decoder from synapse.util.caches.cached_call import RetryOnExceptionCachedCall from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry +from synapse.util.templates import _localpart_from_email_filter if TYPE_CHECKING: from synapse.server import HomeServer @@ -1308,6 +1309,11 @@ def jinja_finalize(thing: Any) -> Any: env = Environment(finalize=jinja_finalize) +env.filters.update( + { + "localpart_from_email": _localpart_from_email_filter, + } +) @attr.s(slots=True, frozen=True, auto_attribs=True) diff --git a/synapse/util/templates.py b/synapse/util/templates.py index 12941065ca..fb758b7180 100644 --- a/synapse/util/templates.py +++ b/synapse/util/templates.py @@ -64,6 +64,7 @@ def build_jinja_env( { "format_ts": _format_ts_filter, "mxc_to_http": _create_mxc_to_http_filter(config.server.public_baseurl), + "localpart_from_email": _localpart_from_email_filter, } ) @@ -112,3 +113,7 @@ def _create_mxc_to_http_filter( def _format_ts_filter(value: int, format: str) -> str: return time.strftime(format, time.localtime(value / 1000)) + + +def _localpart_from_email_filter(address: str) -> str: + return address.rsplit("@", 1)[0] -- cgit 1.5.1 From 90b2327066d2343faa86c464a182b6f3c4422ecd Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 14 Mar 2022 17:52:15 +0000 Subject: Add `delay_cancellation` utility function (#12180) `delay_cancellation` behaves like `stop_cancellation`, except it delays `CancelledError`s until the original `Deferred` resolves. This is handy for unifying cleanup paths and ensuring that uncancelled coroutines don't use finished logcontexts. Signed-off-by: Sean Quah --- changelog.d/12180.misc | 1 + synapse/util/async_helpers.py | 48 +++++++++++++-- tests/util/test_async_helpers.py | 124 +++++++++++++++++++++++++++++++++++++-- 3 files changed, 161 insertions(+), 12 deletions(-) create mode 100644 changelog.d/12180.misc (limited to 'synapse/util') diff --git a/changelog.d/12180.misc b/changelog.d/12180.misc new file mode 100644 index 0000000000..7a347352fd --- /dev/null +++ b/changelog.d/12180.misc @@ -0,0 +1 @@ +Add `delay_cancellation` utility function, which behaves like `stop_cancellation` but waits until the original `Deferred` resolves before raising a `CancelledError`. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index a9f67dcbac..69c8c1baa9 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -686,12 +686,48 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": Synapse logcontext rules. Returns: - A new `Deferred`, which will contain the result of the original `Deferred`, - but will not propagate cancellation through to the original. When cancelled, - the new `Deferred` will fail with a `CancelledError` and will not follow the - Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap - the new `Deferred`. + A new `Deferred`, which will contain the result of the original `Deferred`. + The new `Deferred` will not propagate cancellation through to the original. + When cancelled, the new `Deferred` will fail with a `CancelledError`. + + The new `Deferred` will not follow the Synapse logcontext rules and should be + wrapped with `make_deferred_yieldable`. + """ + new_deferred: "defer.Deferred[T]" = defer.Deferred() + deferred.chainDeferred(new_deferred) + return new_deferred + + +def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": + """Delay cancellation of a `Deferred` until it resolves. + + Has the same effect as `stop_cancellation`, but the returned `Deferred` will not + resolve with a `CancelledError` until the original `Deferred` resolves. + + Args: + deferred: The `Deferred` to protect against cancellation. May optionally follow + the Synapse logcontext rules. + + Returns: + A new `Deferred`, which will contain the result of the original `Deferred`. + The new `Deferred` will not propagate cancellation through to the original. + When cancelled, the new `Deferred` will wait until the original `Deferred` + resolves before failing with a `CancelledError`. + + The new `Deferred` will follow the Synapse logcontext rules if `deferred` + follows the Synapse logcontext rules. Otherwise the new `Deferred` should be + wrapped with `make_deferred_yieldable`. """ - new_deferred: defer.Deferred[T] = defer.Deferred() + + def handle_cancel(new_deferred: "defer.Deferred[T]") -> None: + # before the new deferred is cancelled, we `pause` it to stop the cancellation + # propagating. we then `unpause` it once the wrapped deferred completes, to + # propagate the exception. + new_deferred.pause() + new_deferred.errback(Failure(CancelledError())) + + deferred.addBoth(lambda _: new_deferred.unpause()) + + new_deferred: "defer.Deferred[T]" = defer.Deferred(handle_cancel) deferred.chainDeferred(new_deferred) return new_deferred diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index ff53ce114b..e5bc416de1 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -13,6 +13,8 @@ # limitations under the License. import traceback +from parameterized import parameterized_class + from twisted.internet import defer from twisted.internet.defer import CancelledError, Deferred, ensureDeferred from twisted.internet.task import Clock @@ -23,10 +25,12 @@ from synapse.logging.context import ( LoggingContext, PreserveLoggingContext, current_context, + make_deferred_yieldable, ) from synapse.util.async_helpers import ( ObservableDeferred, concurrently_execute, + delay_cancellation, stop_cancellation, timeout_deferred, ) @@ -313,13 +317,27 @@ class ConcurrentlyExecuteTest(TestCase): self.successResultOf(d2) -class StopCancellationTests(TestCase): - """Tests for the `stop_cancellation` function.""" +@parameterized_class( + ("wrapper",), + [("stop_cancellation",), ("delay_cancellation",)], +) +class CancellationWrapperTests(TestCase): + """Common tests for the `stop_cancellation` and `delay_cancellation` functions.""" + + wrapper: str + + def wrap_deferred(self, deferred: "Deferred[str]") -> "Deferred[str]": + if self.wrapper == "stop_cancellation": + return stop_cancellation(deferred) + elif self.wrapper == "delay_cancellation": + return delay_cancellation(deferred) + else: + raise ValueError(f"Unsupported wrapper type: {self.wrapper}") def test_succeed(self): """Test that the new `Deferred` receives the result.""" deferred: "Deferred[str]" = Deferred() - wrapper_deferred = stop_cancellation(deferred) + wrapper_deferred = self.wrap_deferred(deferred) # Success should propagate through. deferred.callback("success") @@ -329,7 +347,7 @@ class StopCancellationTests(TestCase): def test_failure(self): """Test that the new `Deferred` receives the `Failure`.""" deferred: "Deferred[str]" = Deferred() - wrapper_deferred = stop_cancellation(deferred) + wrapper_deferred = self.wrap_deferred(deferred) # Failure should propagate through. deferred.errback(ValueError("abc")) @@ -337,6 +355,10 @@ class StopCancellationTests(TestCase): self.failureResultOf(wrapper_deferred, ValueError) self.assertIsNone(deferred.result, "`Failure` was not consumed") + +class StopCancellationTests(TestCase): + """Tests for the `stop_cancellation` function.""" + def test_cancellation(self): """Test that cancellation of the new `Deferred` leaves the original running.""" deferred: "Deferred[str]" = Deferred() @@ -347,11 +369,101 @@ class StopCancellationTests(TestCase): self.assertTrue(wrapper_deferred.called) self.failureResultOf(wrapper_deferred, CancelledError) self.assertFalse( - deferred.called, "Original `Deferred` was unexpectedly cancelled." + deferred.called, "Original `Deferred` was unexpectedly cancelled" + ) + + # Now make the original `Deferred` fail. + # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed + # in logs. + deferred.errback(ValueError("abc")) + self.assertIsNone(deferred.result, "`Failure` was not consumed") + + +class DelayCancellationTests(TestCase): + """Tests for the `delay_cancellation` function.""" + + def test_cancellation(self): + """Test that cancellation of the new `Deferred` waits for the original.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = delay_cancellation(deferred) + + # Cancel the new `Deferred`. + wrapper_deferred.cancel() + self.assertNoResult(wrapper_deferred) + self.assertFalse( + deferred.called, "Original `Deferred` was unexpectedly cancelled" + ) + + # Now make the original `Deferred` fail. + # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed + # in logs. + deferred.errback(ValueError("abc")) + self.assertIsNone(deferred.result, "`Failure` was not consumed") + + # Now that the original `Deferred` has failed, we should get a `CancelledError`. + self.failureResultOf(wrapper_deferred, CancelledError) + + def test_suppresses_second_cancellation(self): + """Test that a second cancellation is suppressed. + + Identical to `test_cancellation` except the new `Deferred` is cancelled twice. + """ + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = delay_cancellation(deferred) + + # Cancel the new `Deferred`, twice. + wrapper_deferred.cancel() + wrapper_deferred.cancel() + self.assertNoResult(wrapper_deferred) + self.assertFalse( + deferred.called, "Original `Deferred` was unexpectedly cancelled" ) - # Now make the inner `Deferred` fail. + # Now make the original `Deferred` fail. # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed # in logs. deferred.errback(ValueError("abc")) self.assertIsNone(deferred.result, "`Failure` was not consumed") + + # Now that the original `Deferred` has failed, we should get a `CancelledError`. + self.failureResultOf(wrapper_deferred, CancelledError) + + def test_propagates_cancelled_error(self): + """Test that a `CancelledError` from the original `Deferred` gets propagated.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = delay_cancellation(deferred) + + # Fail the original `Deferred` with a `CancelledError`. + cancelled_error = CancelledError() + deferred.errback(cancelled_error) + + # The new `Deferred` should fail with exactly the same `CancelledError`. + self.assertTrue(wrapper_deferred.called) + self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value) + + def test_preserves_logcontext(self): + """Test that logging contexts are preserved.""" + blocking_d: "Deferred[None]" = Deferred() + + async def inner(): + await make_deferred_yieldable(blocking_d) + + async def outer(): + with LoggingContext("c") as c: + try: + await delay_cancellation(defer.ensureDeferred(inner())) + self.fail("`CancelledError` was not raised") + except CancelledError: + self.assertEqual(c, current_context()) + # Succeed with no error, unless the logging context is wrong. + + # Run and block inside `inner()`. + d = defer.ensureDeferred(outer()) + self.assertEqual(SENTINEL_CONTEXT, current_context()) + + d.cancel() + + # Now unblock. `outer()` will consume the `CancelledError` and check the + # logging context. + blocking_d.callback(None) + self.successResultOf(d) -- cgit 1.5.1 From 605d161d7d585847fd1bb98d14d5281daeac8e86 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 14 Mar 2022 18:49:07 +0000 Subject: Add cancellation support to `ReadWriteLock` (#12120) Also convert `ReadWriteLock` to use async context managers. Signed-off-by: Sean Quah --- changelog.d/12120.misc | 1 + synapse/handlers/pagination.py | 8 +- synapse/util/async_helpers.py | 71 ++++---- tests/util/test_rwlock.py | 395 +++++++++++++++++++++++++++++++++++------ 4 files changed, 382 insertions(+), 93 deletions(-) create mode 100644 changelog.d/12120.misc (limited to 'synapse/util') diff --git a/changelog.d/12120.misc b/changelog.d/12120.misc new file mode 100644 index 0000000000..3603096500 --- /dev/null +++ b/changelog.d/12120.misc @@ -0,0 +1 @@ +Add support for cancellation to `ReadWriteLock`. diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 183fabcfc0..60059fec3e 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -350,7 +350,7 @@ class PaginationHandler: """ self._purges_in_progress_by_room.add(room_id) try: - with await self.pagination_lock.write(room_id): + async with self.pagination_lock.write(room_id): await self.storage.purge_events.purge_history( room_id, token, delete_local_events ) @@ -406,7 +406,7 @@ class PaginationHandler: room_id: room to be purged force: set true to skip checking for joined users. """ - with await self.pagination_lock.write(room_id): + async with self.pagination_lock.write(room_id): # first check that we have no users in this room if not force: joined = await self.store.is_host_joined(room_id, self._server_name) @@ -448,7 +448,7 @@ class PaginationHandler: room_token = from_token.room_key - with await self.pagination_lock.read(room_id): + async with self.pagination_lock.read(room_id): ( membership, member_event_id, @@ -615,7 +615,7 @@ class PaginationHandler: self._purges_in_progress_by_room.add(room_id) try: - with await self.pagination_lock.write(room_id): + async with self.pagination_lock.write(room_id): self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN self._delete_by_id[ delete_id diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 69c8c1baa9..6a8e844d63 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -18,9 +18,10 @@ import collections import inspect import itertools import logging -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from typing import ( Any, + AsyncIterator, Awaitable, Callable, Collection, @@ -40,7 +41,7 @@ from typing import ( ) import attr -from typing_extensions import ContextManager, Literal +from typing_extensions import AsyncContextManager, Literal from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -491,7 +492,7 @@ class ReadWriteLock: Example: - with await read_write_lock.read("test_key"): + async with read_write_lock.read("test_key"): # do some work """ @@ -514,22 +515,24 @@ class ReadWriteLock: # Latest writer queued self.key_to_current_writer: Dict[str, defer.Deferred] = {} - async def read(self, key: str) -> ContextManager: - new_defer: "defer.Deferred[None]" = defer.Deferred() + def read(self, key: str) -> AsyncContextManager: + @asynccontextmanager + async def _ctx_manager() -> AsyncIterator[None]: + new_defer: "defer.Deferred[None]" = defer.Deferred() - curr_readers = self.key_to_current_readers.setdefault(key, set()) - curr_writer = self.key_to_current_writer.get(key, None) + curr_readers = self.key_to_current_readers.setdefault(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) - curr_readers.add(new_defer) + curr_readers.add(new_defer) - # We wait for the latest writer to finish writing. We can safely ignore - # any existing readers... as they're readers. - if curr_writer: - await make_deferred_yieldable(curr_writer) - - @contextmanager - def _ctx_manager() -> Iterator[None]: try: + # We wait for the latest writer to finish writing. We can safely ignore + # any existing readers... as they're readers. + # May raise a `CancelledError` if the `Deferred` wrapping us is + # cancelled. The `Deferred` we are waiting on must not be cancelled, + # since we do not own it. + if curr_writer: + await make_deferred_yieldable(stop_cancellation(curr_writer)) yield finally: with PreserveLoggingContext(): @@ -538,29 +541,35 @@ class ReadWriteLock: return _ctx_manager() - async def write(self, key: str) -> ContextManager: - new_defer: "defer.Deferred[None]" = defer.Deferred() + def write(self, key: str) -> AsyncContextManager: + @asynccontextmanager + async def _ctx_manager() -> AsyncIterator[None]: + new_defer: "defer.Deferred[None]" = defer.Deferred() - curr_readers = self.key_to_current_readers.get(key, set()) - curr_writer = self.key_to_current_writer.get(key, None) + curr_readers = self.key_to_current_readers.get(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) - # We wait on all latest readers and writer. - to_wait_on = list(curr_readers) - if curr_writer: - to_wait_on.append(curr_writer) + # We wait on all latest readers and writer. + to_wait_on = list(curr_readers) + if curr_writer: + to_wait_on.append(curr_writer) - # We can clear the list of current readers since the new writer waits - # for them to finish. - curr_readers.clear() - self.key_to_current_writer[key] = new_defer + # We can clear the list of current readers since `new_defer` waits + # for them to finish. + curr_readers.clear() + self.key_to_current_writer[key] = new_defer - await make_deferred_yieldable(defer.gatherResults(to_wait_on)) - - @contextmanager - def _ctx_manager() -> Iterator[None]: + to_wait_on_defer = defer.gatherResults(to_wait_on) try: + # Wait for all current readers and the latest writer to finish. + # May raise a `CancelledError` immediately after the wait if the + # `Deferred` wrapping us is cancelled. We must only release the lock + # once we have acquired it, hence the use of `delay_cancellation` + # rather than `stop_cancellation`. + await make_deferred_yieldable(delay_cancellation(to_wait_on_defer)) yield finally: + # Release the lock. with PreserveLoggingContext(): new_defer.callback(None) # `self.key_to_current_writer[key]` may be missing if there was another diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index 0774625b85..0c84226197 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import AsyncContextManager, Callable, Sequence, Tuple + from twisted.internet import defer -from twisted.internet.defer import Deferred +from twisted.internet.defer import CancelledError, Deferred from synapse.util.async_helpers import ReadWriteLock @@ -21,87 +23,187 @@ from tests import unittest class ReadWriteLockTestCase(unittest.TestCase): - def _assert_called_before_not_after(self, lst, first_false): - for i, d in enumerate(lst[:first_false]): - self.assertTrue(d.called, msg="%d was unexpectedly false" % i) + def _start_reader_or_writer( + self, + read_or_write: Callable[[str], AsyncContextManager], + key: str, + return_value: str, + ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]: + """Starts a reader or writer which acquires the lock, blocks, then completes. + + Args: + read_or_write: A function returning a context manager for a lock. + Either a bound `ReadWriteLock.read` or `ReadWriteLock.write`. + key: The key to read or write. + return_value: A string that the reader or writer will resolve with when + done. + + Returns: + A tuple of three `Deferred`s: + * A `Deferred` that resolves with `return_value` once the reader or writer + completes successfully. + * A `Deferred` that resolves once the reader or writer acquires the lock. + * A `Deferred` that blocks the reader or writer. Must be resolved by the + caller to allow the reader or writer to release the lock and complete. + """ + acquired_d: "Deferred[None]" = Deferred() + unblock_d: "Deferred[None]" = Deferred() + + async def reader_or_writer(): + async with read_or_write(key): + acquired_d.callback(None) + await unblock_d + return return_value + + d = defer.ensureDeferred(reader_or_writer()) + return d, acquired_d, unblock_d + + def _start_blocking_reader( + self, rwlock: ReadWriteLock, key: str, return_value: str + ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]: + """Starts a reader which acquires the lock, blocks, then releases the lock. + + See the docstring for `_start_reader_or_writer` for details about the arguments + and return values. + """ + return self._start_reader_or_writer(rwlock.read, key, return_value) + + def _start_blocking_writer( + self, rwlock: ReadWriteLock, key: str, return_value: str + ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]: + """Starts a writer which acquires the lock, blocks, then releases the lock. + + See the docstring for `_start_reader_or_writer` for details about the arguments + and return values. + """ + return self._start_reader_or_writer(rwlock.write, key, return_value) + + def _start_nonblocking_reader( + self, rwlock: ReadWriteLock, key: str, return_value: str + ) -> Tuple["Deferred[str]", "Deferred[None]"]: + """Starts a reader which acquires the lock, then releases it immediately. + + See the docstring for `_start_reader_or_writer` for details about the arguments. + + Returns: + A tuple of two `Deferred`s: + * A `Deferred` that resolves with `return_value` once the reader completes + successfully. + * A `Deferred` that resolves once the reader acquires the lock. + """ + d, acquired_d, unblock_d = self._start_reader_or_writer( + rwlock.read, key, return_value + ) + unblock_d.callback(None) + return d, acquired_d + + def _start_nonblocking_writer( + self, rwlock: ReadWriteLock, key: str, return_value: str + ) -> Tuple["Deferred[str]", "Deferred[None]"]: + """Starts a writer which acquires the lock, then releases it immediately. + + See the docstring for `_start_reader_or_writer` for details about the arguments. + + Returns: + A tuple of two `Deferred`s: + * A `Deferred` that resolves with `return_value` once the writer completes + successfully. + * A `Deferred` that resolves once the writer acquires the lock. + """ + d, acquired_d, unblock_d = self._start_reader_or_writer( + rwlock.write, key, return_value + ) + unblock_d.callback(None) + return d, acquired_d + + def _assert_first_n_resolved( + self, deferreds: Sequence["defer.Deferred[None]"], n: int + ) -> None: + """Assert that exactly the first n `Deferred`s in the given list are resolved. - for i, d in enumerate(lst[first_false:]): + Args: + deferreds: The list of `Deferred`s to be checked. + n: The number of `Deferred`s at the start of `deferreds` that should be + resolved. + """ + for i, d in enumerate(deferreds[:n]): + self.assertTrue(d.called, msg="deferred %d was unexpectedly unresolved" % i) + + for i, d in enumerate(deferreds[n:]): self.assertFalse( - d.called, msg="%d was unexpectedly true" % (i + first_false) + d.called, msg="deferred %d was unexpectedly resolved" % (i + n) ) def test_rwlock(self): rwlock = ReadWriteLock() - - key = object() + key = "key" ds = [ - rwlock.read(key), # 0 - rwlock.read(key), # 1 - rwlock.write(key), # 2 - rwlock.write(key), # 3 - rwlock.read(key), # 4 - rwlock.read(key), # 5 - rwlock.write(key), # 6 + self._start_blocking_reader(rwlock, key, "0"), + self._start_blocking_reader(rwlock, key, "1"), + self._start_blocking_writer(rwlock, key, "2"), + self._start_blocking_writer(rwlock, key, "3"), + self._start_blocking_reader(rwlock, key, "4"), + self._start_blocking_reader(rwlock, key, "5"), + self._start_blocking_writer(rwlock, key, "6"), ] - ds = [defer.ensureDeferred(d) for d in ds] + # `Deferred`s that resolve when each reader or writer acquires the lock. + acquired_ds = [acquired_d for _, acquired_d, _ in ds] + # `Deferred`s that will trigger the release of locks when resolved. + release_ds = [release_d for _, _, release_d in ds] - self._assert_called_before_not_after(ds, 2) + # The first two readers should acquire their locks. + self._assert_first_n_resolved(acquired_ds, 2) - with ds[0].result: - self._assert_called_before_not_after(ds, 2) - self._assert_called_before_not_after(ds, 2) + # Release one of the read locks. The next writer should not acquire the lock, + # because there is another reader holding the lock. + self._assert_first_n_resolved(acquired_ds, 2) + release_ds[0].callback(None) + self._assert_first_n_resolved(acquired_ds, 2) - with ds[1].result: - self._assert_called_before_not_after(ds, 2) - self._assert_called_before_not_after(ds, 3) + # Release the other read lock. The next writer should acquire the lock. + self._assert_first_n_resolved(acquired_ds, 2) + release_ds[1].callback(None) + self._assert_first_n_resolved(acquired_ds, 3) - with ds[2].result: - self._assert_called_before_not_after(ds, 3) - self._assert_called_before_not_after(ds, 4) + # Release the write lock. The next writer should acquire the lock. + self._assert_first_n_resolved(acquired_ds, 3) + release_ds[2].callback(None) + self._assert_first_n_resolved(acquired_ds, 4) - with ds[3].result: - self._assert_called_before_not_after(ds, 4) - self._assert_called_before_not_after(ds, 6) + # Release the write lock. The next two readers should acquire locks. + self._assert_first_n_resolved(acquired_ds, 4) + release_ds[3].callback(None) + self._assert_first_n_resolved(acquired_ds, 6) - with ds[5].result: - self._assert_called_before_not_after(ds, 6) - self._assert_called_before_not_after(ds, 6) + # Release one of the read locks. The next writer should not acquire the lock, + # because there is another reader holding the lock. + self._assert_first_n_resolved(acquired_ds, 6) + release_ds[5].callback(None) + self._assert_first_n_resolved(acquired_ds, 6) - with ds[4].result: - self._assert_called_before_not_after(ds, 6) - self._assert_called_before_not_after(ds, 7) + # Release the other read lock. The next writer should acquire the lock. + self._assert_first_n_resolved(acquired_ds, 6) + release_ds[4].callback(None) + self._assert_first_n_resolved(acquired_ds, 7) - with ds[6].result: - pass + # Release the write lock. + release_ds[6].callback(None) - d = defer.ensureDeferred(rwlock.write(key)) - self.assertTrue(d.called) - with d.result: - pass + # Acquire and release the write and read locks one last time for good measure. + _, acquired_d = self._start_nonblocking_writer(rwlock, key, "last writer") + self.assertTrue(acquired_d.called) - d = defer.ensureDeferred(rwlock.read(key)) - self.assertTrue(d.called) - with d.result: - pass + _, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader") + self.assertTrue(acquired_d.called) def test_lock_handoff_to_nonblocking_writer(self): """Test a writer handing the lock to another writer that completes instantly.""" rwlock = ReadWriteLock() key = "key" - unblock: "Deferred[None]" = Deferred() - - async def blocking_write(): - with await rwlock.write(key): - await unblock - - async def nonblocking_write(): - with await rwlock.write(key): - pass - - d1 = defer.ensureDeferred(blocking_write()) - d2 = defer.ensureDeferred(nonblocking_write()) + d1, _, unblock = self._start_blocking_writer(rwlock, key, "write 1 completed") + d2, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed") self.assertFalse(d1.called) self.assertFalse(d2.called) @@ -111,5 +213,182 @@ class ReadWriteLockTestCase(unittest.TestCase): self.assertTrue(d2.called) # The `ReadWriteLock` should operate as normal. - d3 = defer.ensureDeferred(nonblocking_write()) + d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed") self.assertTrue(d3.called) + + def test_cancellation_while_holding_read_lock(self): + """Test cancellation while holding a read lock. + + A waiting writer should be given the lock when the reader holding the lock is + cancelled. + """ + rwlock = ReadWriteLock() + key = "key" + + # 1. A reader takes the lock and blocks. + reader_d, _, _ = self._start_blocking_reader(rwlock, key, "read completed") + + # 2. A writer waits for the reader to complete. + writer_d, _ = self._start_nonblocking_writer(rwlock, key, "write completed") + self.assertFalse(writer_d.called) + + # 3. The reader is cancelled. + reader_d.cancel() + self.failureResultOf(reader_d, CancelledError) + + # 4. The writer should take the lock and complete. + self.assertTrue( + writer_d.called, "Writer is stuck waiting for a cancelled reader" + ) + self.assertEqual("write completed", self.successResultOf(writer_d)) + + def test_cancellation_while_holding_write_lock(self): + """Test cancellation while holding a write lock. + + A waiting reader should be given the lock when the writer holding the lock is + cancelled. + """ + rwlock = ReadWriteLock() + key = "key" + + # 1. A writer takes the lock and blocks. + writer_d, _, _ = self._start_blocking_writer(rwlock, key, "write completed") + + # 2. A reader waits for the writer to complete. + reader_d, _ = self._start_nonblocking_reader(rwlock, key, "read completed") + self.assertFalse(reader_d.called) + + # 3. The writer is cancelled. + writer_d.cancel() + self.failureResultOf(writer_d, CancelledError) + + # 4. The reader should take the lock and complete. + self.assertTrue( + reader_d.called, "Reader is stuck waiting for a cancelled writer" + ) + self.assertEqual("read completed", self.successResultOf(reader_d)) + + def test_cancellation_while_waiting_for_read_lock(self): + """Test cancellation while waiting for a read lock. + + Tests that cancelling a waiting reader: + * does not cancel the writer it is waiting on + * does not cancel the next writer waiting on it + * does not allow the next writer to acquire the lock before an earlier writer + has finished + * does not keep the next writer waiting indefinitely + + These correspond to the asserts with explicit messages. + """ + rwlock = ReadWriteLock() + key = "key" + + # 1. A writer takes the lock and blocks. + writer1_d, _, unblock_writer1 = self._start_blocking_writer( + rwlock, key, "write 1 completed" + ) + + # 2. A reader waits for the first writer to complete. + # This reader will be cancelled later. + reader_d, _ = self._start_nonblocking_reader(rwlock, key, "read completed") + self.assertFalse(reader_d.called) + + # 3. A second writer waits for both the first writer and the reader to complete. + writer2_d, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed") + self.assertFalse(writer2_d.called) + + # 4. The waiting reader is cancelled. + # Neither of the writers should be cancelled. + # The second writer should still be waiting, but only on the first writer. + reader_d.cancel() + self.failureResultOf(reader_d, CancelledError) + self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled") + self.assertFalse( + writer2_d.called, + "Second writer was unexpectedly cancelled or given the lock before the " + "first writer finished", + ) + + # 5. Unblock the first writer, which should complete. + unblock_writer1.callback(None) + self.assertEqual("write 1 completed", self.successResultOf(writer1_d)) + + # 6. The second writer should take the lock and complete. + self.assertTrue( + writer2_d.called, "Second writer is stuck waiting for a cancelled reader" + ) + self.assertEqual("write 2 completed", self.successResultOf(writer2_d)) + + def test_cancellation_while_waiting_for_write_lock(self): + """Test cancellation while waiting for a write lock. + + Tests that cancelling a waiting writer: + * does not cancel the reader or writer it is waiting on + * does not cancel the next writer waiting on it + * does not allow the next writer to acquire the lock before an earlier reader + and writer have finished + * does not keep the next writer waiting indefinitely + + These correspond to the asserts with explicit messages. + """ + rwlock = ReadWriteLock() + key = "key" + + # 1. A reader takes the lock and blocks. + reader_d, _, unblock_reader = self._start_blocking_reader( + rwlock, key, "read completed" + ) + + # 2. A writer waits for the reader to complete. + writer1_d, _, unblock_writer1 = self._start_blocking_writer( + rwlock, key, "write 1 completed" + ) + + # 3. A second writer waits for both the reader and first writer to complete. + # This writer will be cancelled later. + writer2_d, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed") + self.assertFalse(writer2_d.called) + + # 4. A third writer waits for the second writer to complete. + writer3_d, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed") + self.assertFalse(writer3_d.called) + + # 5. The second writer is cancelled, but continues waiting for the lock. + # The reader, first writer and third writer should not be cancelled. + # The first writer should still be waiting on the reader. + # The third writer should still be waiting on the second writer. + writer2_d.cancel() + self.assertNoResult(writer2_d) + self.assertFalse(reader_d.called, "Reader was unexpectedly cancelled") + self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled") + self.assertFalse( + writer3_d.called, + "Third writer was unexpectedly cancelled or given the lock before the first " + "writer finished", + ) + + # 6. Unblock the reader, which should complete. + # The first writer should be given the lock and block. + # The third writer should still be waiting on the second writer. + unblock_reader.callback(None) + self.assertEqual("read completed", self.successResultOf(reader_d)) + self.assertNoResult(writer2_d) + self.assertFalse( + writer3_d.called, + "Third writer was unexpectedly given the lock before the first writer " + "finished", + ) + + # 7. Unblock the first writer, which should complete. + unblock_writer1.callback(None) + self.assertEqual("write 1 completed", self.successResultOf(writer1_d)) + + # 8. The second writer should take the lock and release it immediately, since it + # has been cancelled. + self.failureResultOf(writer2_d, CancelledError) + + # 9. The third writer should take the lock and complete. + self.assertTrue( + writer3_d.called, "Third writer is stuck waiting for a cancelled writer" + ) + self.assertEqual("write 3 completed", self.successResultOf(writer3_d)) -- cgit 1.5.1 From 2fcf4b3f6cd2a0be6597622664636d2219957c2a Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 14 Mar 2022 19:04:29 +0000 Subject: Add cancellation support to `@cached` and `@cachedList` decorators (#12183) These decorators mostly support cancellation already. Add cancellation tests and fix use of finished logging contexts by delaying cancellation, as suggested by @erikjohnston. Signed-off-by: Sean Quah --- changelog.d/12183.misc | 1 + synapse/util/caches/descriptors.py | 11 +++ tests/util/caches/test_descriptors.py | 147 +++++++++++++++++++++++++++++++++- 3 files changed, 157 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12183.misc (limited to 'synapse/util') diff --git a/changelog.d/12183.misc b/changelog.d/12183.misc new file mode 100644 index 0000000000..dd441bb64f --- /dev/null +++ b/changelog.d/12183.misc @@ -0,0 +1 @@ +Add cancellation support to `@cached` and `@cachedList` decorators. diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index c3c5c16db9..eda92d864d 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -41,6 +41,7 @@ from twisted.python.failure import Failure from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError +from synapse.util.async_helpers import delay_cancellation from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.lrucache import LruCache @@ -350,6 +351,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs) ret = cache.set(cache_key, ret, callback=invalidate_callback) + # We started a new call to `self.orig`, so we must always wait for it to + # complete. Otherwise we might mark our current logging context as + # finished while `self.orig` is still using it in the background. + ret = delay_cancellation(ret) + return make_deferred_yieldable(ret) wrapped = cast(_CachedFunction, _wrapped) @@ -510,6 +516,11 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( lambda _: results, unwrapFirstError ) + if missing: + # We started a new call to `self.orig`, so we must always wait for it to + # complete. Otherwise we might mark our current logging context as + # finished while `self.orig` is still using it in the background. + d = delay_cancellation(d) return make_deferred_yieldable(d) else: return defer.succeed(results) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 6a4b17527a..48e616ac74 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -17,7 +17,7 @@ from typing import Set from unittest import mock from twisted.internet import defer, reactor -from twisted.internet.defer import Deferred +from twisted.internet.defer import CancelledError, Deferred from synapse.api.errors import SynapseError from synapse.logging.context import ( @@ -28,7 +28,7 @@ from synapse.logging.context import ( make_deferred_yieldable, ) from synapse.util.caches import descriptors -from synapse.util.caches.descriptors import cached, lru_cache +from synapse.util.caches.descriptors import cached, cachedList, lru_cache from tests import unittest from tests.test_utils import get_awaitable_result @@ -493,6 +493,74 @@ class DescriptorTestCase(unittest.TestCase): obj.invalidate() top_invalidate.assert_called_once() + def test_cancel(self): + """Test that cancelling a lookup does not cancel other lookups""" + complete_lookup: "Deferred[None]" = Deferred() + + class Cls: + @cached() + async def fn(self, arg1): + await complete_lookup + return str(arg1) + + obj = Cls() + + d1 = obj.fn(123) + d2 = obj.fn(123) + self.assertFalse(d1.called) + self.assertFalse(d2.called) + + # Cancel `d1`, which is the lookup that caused `fn` to run. + d1.cancel() + + # `d2` should complete normally. + complete_lookup.callback(None) + self.failureResultOf(d1, CancelledError) + self.assertEqual(d2.result, "123") + + def test_cancel_logcontexts(self): + """Test that cancellation does not break logcontexts. + + * The `CancelledError` must be raised with the correct logcontext. + * The inner lookup must not resume with a finished logcontext. + * The inner lookup must not restore a finished logcontext when done. + """ + complete_lookup: "Deferred[None]" = Deferred() + + class Cls: + inner_context_was_finished = False + + @cached() + async def fn(self, arg1): + await make_deferred_yieldable(complete_lookup) + self.inner_context_was_finished = current_context().finished + return str(arg1) + + obj = Cls() + + async def do_lookup(): + with LoggingContext("c1") as c1: + try: + await obj.fn(123) + self.fail("No CancelledError thrown") + except CancelledError: + self.assertEqual( + current_context(), + c1, + "CancelledError was not raised with the correct logcontext", + ) + # suppress the error and succeed + + d = defer.ensureDeferred(do_lookup()) + d.cancel() + + complete_lookup.callback(None) + self.successResultOf(d) + self.assertFalse( + obj.inner_context_was_finished, "Tried to restart a finished logcontext" + ) + self.assertEqual(current_context(), SENTINEL_CONTEXT) + class CacheDecoratorTestCase(unittest.HomeserverTestCase): """More tests for @cached @@ -865,3 +933,78 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj.fn.invalidate((10, 2)) invalidate0.assert_called_once() invalidate1.assert_called_once() + + def test_cancel(self): + """Test that cancelling a lookup does not cancel other lookups""" + complete_lookup: "Deferred[None]" = Deferred() + + class Cls: + @cached() + def fn(self, arg1): + pass + + @cachedList(cached_method_name="fn", list_name="args") + async def list_fn(self, args): + await complete_lookup + return {arg: str(arg) for arg in args} + + obj = Cls() + + d1 = obj.list_fn([123, 456]) + d2 = obj.list_fn([123, 456, 789]) + self.assertFalse(d1.called) + self.assertFalse(d2.called) + + d1.cancel() + + # `d2` should complete normally. + complete_lookup.callback(None) + self.failureResultOf(d1, CancelledError) + self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"}) + + def test_cancel_logcontexts(self): + """Test that cancellation does not break logcontexts. + + * The `CancelledError` must be raised with the correct logcontext. + * The inner lookup must not resume with a finished logcontext. + * The inner lookup must not restore a finished logcontext when done. + """ + complete_lookup: "Deferred[None]" = Deferred() + + class Cls: + inner_context_was_finished = False + + @cached() + def fn(self, arg1): + pass + + @cachedList(cached_method_name="fn", list_name="args") + async def list_fn(self, args): + await make_deferred_yieldable(complete_lookup) + self.inner_context_was_finished = current_context().finished + return {arg: str(arg) for arg in args} + + obj = Cls() + + async def do_lookup(): + with LoggingContext("c1") as c1: + try: + await obj.list_fn([123]) + self.fail("No CancelledError thrown") + except CancelledError: + self.assertEqual( + current_context(), + c1, + "CancelledError was not raised with the correct logcontext", + ) + # suppress the error and succeed + + d = defer.ensureDeferred(do_lookup()) + d.cancel() + + complete_lookup.callback(None) + self.successResultOf(d) + self.assertFalse( + obj.inner_context_was_finished, "Tried to restart a finished logcontext" + ) + self.assertEqual(current_context(), SENTINEL_CONTEXT) -- cgit 1.5.1 From bf9d549e3ad944e1e53a2ecc898640d690bf1eac Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 18 Mar 2022 19:03:46 +0000 Subject: Try to detect borked package installations. (#12244) * Try to detect borked package installations. Fixes #12223. Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/12244.misc | 1 + synapse/util/check_dependencies.py | 24 +++++++++++++++++++++++- tests/util/test_check_dependencies.py | 15 ++++++++++++++- 3 files changed, 38 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12244.misc (limited to 'synapse/util') diff --git a/changelog.d/12244.misc b/changelog.d/12244.misc new file mode 100644 index 0000000000..950d48e4c6 --- /dev/null +++ b/changelog.d/12244.misc @@ -0,0 +1 @@ +Improve error message when dependencies check finds a broken installation. \ No newline at end of file diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py index 12cd804939..66f1da7502 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py @@ -128,6 +128,19 @@ def _incorrect_version( ) +def _no_reported_version(requirement: Requirement, extra: Optional[str] = None) -> str: + if extra: + return ( + f"Synapse {VERSION} needs {requirement} for {extra}, " + f"but can't determine {requirement.name}'s version" + ) + else: + return ( + f"Synapse {VERSION} needs {requirement}, " + f"but can't determine {requirement.name}'s version" + ) + + def check_requirements(extra: Optional[str] = None) -> None: """Check Synapse's dependencies are present and correctly versioned. @@ -163,8 +176,17 @@ def check_requirements(extra: Optional[str] = None) -> None: deps_unfulfilled.append(requirement.name) errors.append(_not_installed(requirement, extra)) else: + if dist.version is None: + # This shouldn't happen---it suggests a borked virtualenv. (See #12223) + # Try to give a vaguely helpful error message anyway. + # Type-ignore: the annotations don't reflect reality: see + # https://github.com/python/typeshed/issues/7513 + # https://bugs.python.org/issue47060 + deps_unfulfilled.append(requirement.name) # type: ignore[unreachable] + errors.append(_no_reported_version(requirement, extra)) + # We specify prereleases=True to allow prereleases such as RCs. - if not requirement.specifier.contains(dist.version, prereleases=True): + elif not requirement.specifier.contains(dist.version, prereleases=True): deps_unfulfilled.append(requirement.name) errors.append(_incorrect_version(requirement, dist.version, extra)) diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py index 38e9f58ac6..5d1aa025d1 100644 --- a/tests/util/test_check_dependencies.py +++ b/tests/util/test_check_dependencies.py @@ -12,7 +12,7 @@ from tests.unittest import TestCase class DummyDistribution(metadata.Distribution): - def __init__(self, version: str): + def __init__(self, version: object): self._version = version @property @@ -30,6 +30,7 @@ old = DummyDistribution("0.1.2") old_release_candidate = DummyDistribution("0.1.2rc3") new = DummyDistribution("1.2.3") new_release_candidate = DummyDistribution("1.2.3rc4") +distribution_with_no_version = DummyDistribution(None) # could probably use stdlib TestCase --- no need for twisted here @@ -67,6 +68,18 @@ class TestDependencyChecker(TestCase): # should not raise check_requirements() + def test_version_reported_as_none(self) -> None: + """Complain if importlib.metadata.version() returns None. + + This shouldn't normally happen, but it was seen in the wild (#12223). + """ + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1"], + ): + with self.mock_installed_package(distribution_with_no_version): + self.assertRaises(DependencyException, check_requirements) + def test_checks_ignore_dev_dependencies(self) -> None: """Bot generic and per-extra checks should ignore dev dependencies.""" with patch( -- cgit 1.5.1