summary refs log tree commit diff
path: root/synapse/util/distributor.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/distributor.py')
-rw-r--r--synapse/util/distributor.py29
1 files changed, 23 insertions, 6 deletions
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 91837655f8..b580bdd0de 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,7 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, Dict, List
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    TypeVar,
+    Union,
+)
+
+from typing_extensions import ParamSpec
 
 from twisted.internet import defer
 
@@ -75,7 +87,11 @@ class Distributor:
         run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
 
 
-class Signal:
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+class Signal(Generic[P]):
     """A Signal is a dispatch point that stores a list of callables as
     observers of it.
 
@@ -87,16 +103,16 @@ class Signal:
 
     def __init__(self, name: str):
         self.name: str = name
-        self.observers: List[Callable] = []
+        self.observers: List[Callable[P, Any]] = []
 
-    def observe(self, observer: Callable) -> None:
+    def observe(self, observer: Callable[P, Any]) -> None:
         """Adds a new callable to the observer list which will be invoked by
         the 'fire' method.
 
         Each observer callable may return a Deferred."""
         self.observers.append(observer)
 
-    def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]":
+    def fire(self, *args: P.args, **kwargs: P.kwargs) -> "defer.Deferred[List[Any]]":
         """Invokes every callable in the observer list, passing in the args and
         kwargs. Exceptions thrown by observers are logged but ignored. It is
         not an error to fire a signal with no observers.
@@ -104,7 +120,7 @@ class Signal:
         Returns a Deferred that will complete when all the observers have
         completed."""
 
-        async def do(observer: Callable[..., Any]) -> Any:
+        async def do(observer: Callable[P, Union[R, Awaitable[R]]]) -> Optional[R]:
             try:
                 return await maybe_awaitable(observer(*args, **kwargs))
             except Exception as e:
@@ -114,6 +130,7 @@ class Signal:
                     observer,
                     e,
                 )
+                return None
 
         deferreds = [run_in_background(do, o) for o in self.observers]