diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 91837655f8..b580bdd0de 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,7 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, List
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Generic,
+ List,
+ Optional,
+ TypeVar,
+ Union,
+)
+
+from typing_extensions import ParamSpec
from twisted.internet import defer
@@ -75,7 +87,11 @@ class Distributor:
run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
-class Signal:
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+class Signal(Generic[P]):
"""A Signal is a dispatch point that stores a list of callables as
observers of it.
@@ -87,16 +103,16 @@ class Signal:
def __init__(self, name: str):
self.name: str = name
- self.observers: List[Callable] = []
+ self.observers: List[Callable[P, Any]] = []
- def observe(self, observer: Callable) -> None:
+ def observe(self, observer: Callable[P, Any]) -> None:
"""Adds a new callable to the observer list which will be invoked by
the 'fire' method.
Each observer callable may return a Deferred."""
self.observers.append(observer)
- def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]":
+ def fire(self, *args: P.args, **kwargs: P.kwargs) -> "defer.Deferred[List[Any]]":
"""Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is
not an error to fire a signal with no observers.
@@ -104,7 +120,7 @@ class Signal:
Returns a Deferred that will complete when all the observers have
completed."""
- async def do(observer: Callable[..., Any]) -> Any:
+ async def do(observer: Callable[P, Union[R, Awaitable[R]]]) -> Optional[R]:
try:
return await maybe_awaitable(observer(*args, **kwargs))
except Exception as e:
@@ -114,6 +130,7 @@ class Signal:
observer,
e,
)
+ return None
deferreds = [run_in_background(do, o) for o in self.observers]
|