diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index e27c5d298f..b91020117f 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -42,7 +42,7 @@ from typing import (
)
import attr
-from typing_extensions import AsyncContextManager, Literal
+from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec
from twisted.internet import defer
from twisted.internet.defer import CancelledError
@@ -237,9 +237,16 @@ async def concurrently_execute(
)
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
async def yieldable_gather_results(
- func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any
-) -> List[T]:
+ func: Callable[Concatenate[T, P], Awaitable[R]],
+ iter: Iterable[T],
+ *args: P.args,
+ **kwargs: P.kwargs,
+) -> List[R]:
"""Executes the function with each argument concurrently.
Args:
@@ -255,7 +262,15 @@ async def yieldable_gather_results(
try:
return await make_deferred_yieldable(
defer.gatherResults(
- [run_in_background(func, item, *args, **kwargs) for item in iter],
+ # type-ignore: mypy reports two errors:
+ # error: Argument 1 to "run_in_background" has incompatible type
+ # "Callable[[T, **P], Awaitable[R]]"; expected
+ # "Callable[[T, **P], Awaitable[R]]" [arg-type]
+ # error: Argument 2 to "run_in_background" has incompatible type
+ # "T"; expected "[T, **P.args]" [arg-type]
+ # The former looks like a mypy bug, and the latter looks like a
+ # false positive.
+ [run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type]
consumeErrors=True,
)
)
@@ -577,9 +592,6 @@ class ReadWriteLock:
return _ctx_manager()
-R = TypeVar("R")
-
-
def timeout_deferred(
deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime
) -> "defer.Deferred[_T]":
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 91837655f8..b580bdd0de 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,7 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, List
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Generic,
+ List,
+ Optional,
+ TypeVar,
+ Union,
+)
+
+from typing_extensions import ParamSpec
from twisted.internet import defer
@@ -75,7 +87,11 @@ class Distributor:
run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
-class Signal:
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+class Signal(Generic[P]):
"""A Signal is a dispatch point that stores a list of callables as
observers of it.
@@ -87,16 +103,16 @@ class Signal:
def __init__(self, name: str):
self.name: str = name
- self.observers: List[Callable] = []
+ self.observers: List[Callable[P, Any]] = []
- def observe(self, observer: Callable) -> None:
+ def observe(self, observer: Callable[P, Any]) -> None:
"""Adds a new callable to the observer list which will be invoked by
the 'fire' method.
Each observer callable may return a Deferred."""
self.observers.append(observer)
- def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]":
+ def fire(self, *args: P.args, **kwargs: P.kwargs) -> "defer.Deferred[List[Any]]":
"""Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is
not an error to fire a signal with no observers.
@@ -104,7 +120,7 @@ class Signal:
Returns a Deferred that will complete when all the observers have
completed."""
- async def do(observer: Callable[..., Any]) -> Any:
+ async def do(observer: Callable[P, Union[R, Awaitable[R]]]) -> Optional[R]:
try:
return await maybe_awaitable(observer(*args, **kwargs))
except Exception as e:
@@ -114,6 +130,7 @@ class Signal:
observer,
e,
)
+ return None
deferreds = [run_in_background(do, o) for o in self.observers]
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 98ee49af6e..bc3b4938ea 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -15,10 +15,10 @@
import logging
from functools import wraps
from types import TracebackType
-from typing import Any, Callable, Optional, Type, TypeVar, cast
+from typing import Awaitable, Callable, Optional, Type, TypeVar
from prometheus_client import Counter
-from typing_extensions import Protocol
+from typing_extensions import Concatenate, ParamSpec, Protocol
from synapse.logging.context import (
ContextResourceUsage,
@@ -72,16 +72,21 @@ in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge(
)
-T = TypeVar("T", bound=Callable[..., Any])
+P = ParamSpec("P")
+R = TypeVar("R")
class HasClock(Protocol):
clock: Clock
-def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
- """
- Used to decorate an async function with a `Measure` context manager.
+def measure_func(
+ name: Optional[str] = None,
+) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
+ """Decorate an async method with a `Measure` context manager.
+
+ The Measure is created using `self.clock`; it should only be used to decorate
+ methods in classes defining an instance-level `clock` attribute.
Usage:
@@ -97,18 +102,24 @@ def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
"""
- def wrapper(func: T) -> T:
+ def wrapper(
+ func: Callable[Concatenate[HasClock, P], Awaitable[R]]
+ ) -> Callable[P, Awaitable[R]]:
block_name = func.__name__ if name is None else name
@wraps(func)
- async def measured_func(self: HasClock, *args: Any, **kwargs: Any) -> Any:
+ async def measured_func(self: HasClock, *args: P.args, **kwargs: P.kwargs) -> R:
with Measure(self.clock, block_name):
r = await func(self, *args, **kwargs)
return r
- return cast(T, measured_func)
+ # There are some shenanigans here, because we're decorating a method but
+ # explicitly making use of the `self` parameter. The key thing here is that the
+ # return type within the return type for `measure_func` itself describes how the
+ # decorated function will be called.
+ return measured_func # type: ignore[return-value]
- return wrapper
+ return wrapper # type: ignore[return-value]
class Measure:
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index dace68666c..f97f98a057 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -16,6 +16,8 @@ import functools
import sys
from typing import Any, Callable, Generator, List, TypeVar, cast
+from typing_extensions import ParamSpec
+
from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
@@ -25,6 +27,7 @@ _already_patched = False
T = TypeVar("T")
+P = ParamSpec("P")
def do_patch() -> None:
@@ -41,13 +44,13 @@ def do_patch() -> None:
return
def new_inline_callbacks(
- f: Callable[..., Generator["Deferred[object]", object, T]]
- ) -> Callable[..., "Deferred[T]"]:
+ f: Callable[P, Generator["Deferred[object]", object, T]]
+ ) -> Callable[P, "Deferred[T]"]:
@functools.wraps(f)
- def wrapped(*args: Any, **kwargs: Any) -> "Deferred[T]":
+ def wrapped(*args: P.args, **kwargs: P.kwargs) -> "Deferred[T]":
start_context = current_context()
changes: List[str] = []
- orig: Callable[..., "Deferred[T]"] = orig_inline_callbacks(
+ orig: Callable[P, "Deferred[T]"] = orig_inline_callbacks(
_check_yield_points(f, changes)
)
@@ -115,7 +118,7 @@ def do_patch() -> None:
def _check_yield_points(
- f: Callable[..., Generator["Deferred[object]", object, T]],
+ f: Callable[P, Generator["Deferred[object]", object, T]],
changes: List[str],
) -> Callable:
"""Wraps a generator that is about to be passed to defer.inlineCallbacks
@@ -138,7 +141,7 @@ def _check_yield_points(
@functools.wraps(f)
def check_yield_points_inner(
- *args: Any, **kwargs: Any
+ *args: P.args, **kwargs: P.kwargs
) -> Generator["Deferred[object]", object, T]:
gen = f(*args, **kwargs)
|