diff options
author | David Robertson <davidr@element.io> | 2022-05-09 11:27:39 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-09 10:27:39 +0000 |
commit | fa0eab9c8e159b698a31fc7cfaafed643f47e284 (patch) | |
tree | 10b0b3d1c09fdf88b7c227be9976999878f2f377 /synapse/util | |
parent | Don't error on unknown receipt types (#12670) (diff) | |
download | synapse-fa0eab9c8e159b698a31fc7cfaafed643f47e284.tar.xz |
Use `ParamSpec` in a few places (#12667)
Diffstat (limited to 'synapse/util')
-rw-r--r-- | synapse/util/async_helpers.py | 26 | ||||
-rw-r--r-- | synapse/util/distributor.py | 29 | ||||
-rw-r--r-- | synapse/util/metrics.py | 31 | ||||
-rw-r--r-- | synapse/util/patch_inline_callbacks.py | 15 |
4 files changed, 72 insertions, 29 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index e27c5d298f..b91020117f 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -42,7 +42,7 @@ from typing import ( ) import attr -from typing_extensions import AsyncContextManager, Literal +from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -237,9 +237,16 @@ async def concurrently_execute( ) +P = ParamSpec("P") +R = TypeVar("R") + + async def yieldable_gather_results( - func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any -) -> List[T]: + func: Callable[Concatenate[T, P], Awaitable[R]], + iter: Iterable[T], + *args: P.args, + **kwargs: P.kwargs, +) -> List[R]: """Executes the function with each argument concurrently. Args: @@ -255,7 +262,15 @@ async def yieldable_gather_results( try: return await make_deferred_yieldable( defer.gatherResults( - [run_in_background(func, item, *args, **kwargs) for item in iter], + # type-ignore: mypy reports two errors: + # error: Argument 1 to "run_in_background" has incompatible type + # "Callable[[T, **P], Awaitable[R]]"; expected + # "Callable[[T, **P], Awaitable[R]]" [arg-type] + # error: Argument 2 to "run_in_background" has incompatible type + # "T"; expected "[T, **P.args]" [arg-type] + # The former looks like a mypy bug, and the latter looks like a + # false positive. + [run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type] consumeErrors=True, ) ) @@ -577,9 +592,6 @@ class ReadWriteLock: return _ctx_manager() -R = TypeVar("R") - - def timeout_deferred( deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime ) -> "defer.Deferred[_T]": diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 91837655f8..b580bdd0de 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -12,7 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Dict, List +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Generic, + List, + Optional, + TypeVar, + Union, +) + +from typing_extensions import ParamSpec from twisted.internet import defer @@ -75,7 +87,11 @@ class Distributor: run_as_background_process(name, self.signals[name].fire, *args, **kwargs) -class Signal: +P = ParamSpec("P") +R = TypeVar("R") + + +class Signal(Generic[P]): """A Signal is a dispatch point that stores a list of callables as observers of it. @@ -87,16 +103,16 @@ class Signal: def __init__(self, name: str): self.name: str = name - self.observers: List[Callable] = [] + self.observers: List[Callable[P, Any]] = [] - def observe(self, observer: Callable) -> None: + def observe(self, observer: Callable[P, Any]) -> None: """Adds a new callable to the observer list which will be invoked by the 'fire' method. Each observer callable may return a Deferred.""" self.observers.append(observer) - def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]": + def fire(self, *args: P.args, **kwargs: P.kwargs) -> "defer.Deferred[List[Any]]": """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. It is not an error to fire a signal with no observers. @@ -104,7 +120,7 @@ class Signal: Returns a Deferred that will complete when all the observers have completed.""" - async def do(observer: Callable[..., Any]) -> Any: + async def do(observer: Callable[P, Union[R, Awaitable[R]]]) -> Optional[R]: try: return await maybe_awaitable(observer(*args, **kwargs)) except Exception as e: @@ -114,6 +130,7 @@ class Signal: observer, e, ) + return None deferreds = [run_in_background(do, o) for o in self.observers] diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 98ee49af6e..bc3b4938ea 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -15,10 +15,10 @@ import logging from functools import wraps from types import TracebackType -from typing import Any, Callable, Optional, Type, TypeVar, cast +from typing import Awaitable, Callable, Optional, Type, TypeVar from prometheus_client import Counter -from typing_extensions import Protocol +from typing_extensions import Concatenate, ParamSpec, Protocol from synapse.logging.context import ( ContextResourceUsage, @@ -72,16 +72,21 @@ in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge( ) -T = TypeVar("T", bound=Callable[..., Any]) +P = ParamSpec("P") +R = TypeVar("R") class HasClock(Protocol): clock: Clock -def measure_func(name: Optional[str] = None) -> Callable[[T], T]: - """ - Used to decorate an async function with a `Measure` context manager. +def measure_func( + name: Optional[str] = None, +) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: + """Decorate an async method with a `Measure` context manager. + + The Measure is created using `self.clock`; it should only be used to decorate + methods in classes defining an instance-level `clock` attribute. Usage: @@ -97,18 +102,24 @@ def measure_func(name: Optional[str] = None) -> Callable[[T], T]: """ - def wrapper(func: T) -> T: + def wrapper( + func: Callable[Concatenate[HasClock, P], Awaitable[R]] + ) -> Callable[P, Awaitable[R]]: block_name = func.__name__ if name is None else name @wraps(func) - async def measured_func(self: HasClock, *args: Any, **kwargs: Any) -> Any: + async def measured_func(self: HasClock, *args: P.args, **kwargs: P.kwargs) -> R: with Measure(self.clock, block_name): r = await func(self, *args, **kwargs) return r - return cast(T, measured_func) + # There are some shenanigans here, because we're decorating a method but + # explicitly making use of the `self` parameter. The key thing here is that the + # return type within the return type for `measure_func` itself describes how the + # decorated function will be called. + return measured_func # type: ignore[return-value] - return wrapper + return wrapper # type: ignore[return-value] class Measure: diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index dace68666c..f97f98a057 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -16,6 +16,8 @@ import functools import sys from typing import Any, Callable, Generator, List, TypeVar, cast +from typing_extensions import ParamSpec + from twisted.internet import defer from twisted.internet.defer import Deferred from twisted.python.failure import Failure @@ -25,6 +27,7 @@ _already_patched = False T = TypeVar("T") +P = ParamSpec("P") def do_patch() -> None: @@ -41,13 +44,13 @@ def do_patch() -> None: return def new_inline_callbacks( - f: Callable[..., Generator["Deferred[object]", object, T]] - ) -> Callable[..., "Deferred[T]"]: + f: Callable[P, Generator["Deferred[object]", object, T]] + ) -> Callable[P, "Deferred[T]"]: @functools.wraps(f) - def wrapped(*args: Any, **kwargs: Any) -> "Deferred[T]": + def wrapped(*args: P.args, **kwargs: P.kwargs) -> "Deferred[T]": start_context = current_context() changes: List[str] = [] - orig: Callable[..., "Deferred[T]"] = orig_inline_callbacks( + orig: Callable[P, "Deferred[T]"] = orig_inline_callbacks( _check_yield_points(f, changes) ) @@ -115,7 +118,7 @@ def do_patch() -> None: def _check_yield_points( - f: Callable[..., Generator["Deferred[object]", object, T]], + f: Callable[P, Generator["Deferred[object]", object, T]], changes: List[str], ) -> Callable: """Wraps a generator that is about to be passed to defer.inlineCallbacks @@ -138,7 +141,7 @@ def _check_yield_points( @functools.wraps(f) def check_yield_points_inner( - *args: Any, **kwargs: Any + *args: P.args, **kwargs: P.kwargs ) -> Generator["Deferred[object]", object, T]: gen = f(*args, **kwargs) |