summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-05-09 11:27:39 +0100
committerGitHub <noreply@github.com>2022-05-09 10:27:39 +0000
commitfa0eab9c8e159b698a31fc7cfaafed643f47e284 (patch)
tree10b0b3d1c09fdf88b7c227be9976999878f2f377 /synapse
parentDon't error on unknown receipt types (#12670) (diff)
downloadsynapse-fa0eab9c8e159b698a31fc7cfaafed643f47e284.tar.xz
Use `ParamSpec` in a few places (#12667)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/_base.py14
-rw-r--r--synapse/events/presence_router.py15
-rw-r--r--synapse/module_api/__init__.py17
-rw-r--r--synapse/rest/client/knock.py4
-rw-r--r--synapse/rest/client/transactions.py19
-rw-r--r--synapse/storage/database.py31
-rw-r--r--synapse/storage/databases/main/events.py8
-rw-r--r--synapse/util/async_helpers.py26
-rw-r--r--synapse/util/distributor.py29
-rw-r--r--synapse/util/metrics.py31
-rw-r--r--synapse/util/patch_inline_callbacks.py15
11 files changed, 143 insertions, 66 deletions
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index d28b87a3f4..3623c1724d 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -38,6 +38,7 @@ from typing import (
 
 from cryptography.utils import CryptographyDeprecationWarning
 from matrix_common.versionstring import get_distribution_version_string
+from typing_extensions import ParamSpec
 
 import twisted
 from twisted.internet import defer, error, reactor as _reactor
@@ -81,11 +82,12 @@ logger = logging.getLogger(__name__)
 
 # list of tuples of function, args list, kwargs dict
 _sighup_callbacks: List[
-    Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
+    Tuple[Callable[..., None], Tuple[object, ...], Dict[str, object]]
 ] = []
+P = ParamSpec("P")
 
 
-def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> None:
+def register_sighup(func: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None:
     """
     Register a function to be called when a SIGHUP occurs.
 
@@ -93,7 +95,9 @@ def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> Non
         func: Function to be called when sent a SIGHUP signal.
         *args, **kwargs: args and kwargs to be passed to the target function.
     """
-    _sighup_callbacks.append((func, args, kwargs))
+    # This type-ignore should be redundant once we use a mypy release with
+    # https://github.com/python/mypy/pull/12668.
+    _sighup_callbacks.append((func, args, kwargs))  # type: ignore[arg-type]
 
 
 def start_worker_reactor(
@@ -214,7 +218,9 @@ def redirect_stdio_to_logs() -> None:
     print("Redirected stdout/stderr to logs")
 
 
-def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> None:
+def register_start(
+    cb: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
+) -> None:
     """Register a callback with the reactor, to be called once it is running
 
     This can be used to initialise parts of the system which require an asynchronous
diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py
index 98555c8c0c..8437ce52dc 100644
--- a/synapse/events/presence_router.py
+++ b/synapse/events/presence_router.py
@@ -22,9 +22,12 @@ from typing import (
     List,
     Optional,
     Set,
+    TypeVar,
     Union,
 )
 
+from typing_extensions import ParamSpec
+
 from synapse.api.presence import UserPresenceState
 from synapse.util.async_helpers import maybe_awaitable
 
@@ -40,6 +43,10 @@ GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]
 logger = logging.getLogger(__name__)
 
 
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
 def load_legacy_presence_router(hs: "HomeServer") -> None:
     """Wrapper that loads a presence router module configured using the old
     configuration, and registers the hooks they implement.
@@ -63,13 +70,15 @@ def load_legacy_presence_router(hs: "HomeServer") -> None:
 
     # All methods that the module provides should be async, but this wasn't enforced
     # in the old module system, so we wrap them if needed
-    def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
+    def async_wrapper(
+        f: Optional[Callable[P, R]]
+    ) -> Optional[Callable[P, Awaitable[R]]]:
         # f might be None if the callback isn't implemented by the module. In this
         # case we don't want to register a callback at all so we return None.
         if f is None:
             return None
 
-        def run(*args: Any, **kwargs: Any) -> Awaitable:
+        def run(*args: P.args, **kwargs: P.kwargs) -> Awaitable[R]:
             # Assertion required because mypy can't prove we won't change `f`
             # back to `None`. See
             # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
@@ -80,7 +89,7 @@ def load_legacy_presence_router(hs: "HomeServer") -> None:
         return run
 
     # Register the hooks through the module API.
-    hooks = {
+    hooks: Dict[str, Optional[Callable[..., Any]]] = {
         hook: async_wrapper(getattr(presence_router, hook, None))
         for hook in presence_router_methods
     }
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 834fe1b62c..73f92d2df8 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -30,6 +30,7 @@ from typing import (
 
 import attr
 import jinja2
+from typing_extensions import ParamSpec
 
 from twisted.internet import defer
 from twisted.web.resource import Resource
@@ -129,6 +130,7 @@ if TYPE_CHECKING:
 
 
 T = TypeVar("T")
+P = ParamSpec("P")
 
 """
 This package defines the 'stable' API which can be used by extension modules which
@@ -799,9 +801,9 @@ class ModuleApi:
     def run_db_interaction(
         self,
         desc: str,
-        func: Callable[..., T],
-        *args: Any,
-        **kwargs: Any,
+        func: Callable[P, T],
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> "defer.Deferred[T]":
         """Run a function with a database connection
 
@@ -817,8 +819,9 @@ class ModuleApi:
         Returns:
             Deferred[object]: result of func
         """
+        # type-ignore: See https://github.com/python/mypy/issues/8862
         return defer.ensureDeferred(
-            self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
+            self._store.db_pool.runInteraction(desc, func, *args, **kwargs)  # type: ignore[arg-type]
         )
 
     def complete_sso_login(
@@ -1296,9 +1299,9 @@ class ModuleApi:
 
     async def defer_to_thread(
         self,
-        f: Callable[..., T],
-        *args: Any,
-        **kwargs: Any,
+        f: Callable[P, T],
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> T:
         """Runs the given function in a separate thread from Synapse's thread pool.
 
diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py
index 0152a0c66a..ad025c8a45 100644
--- a/synapse/rest/client/knock.py
+++ b/synapse/rest/client/knock.py
@@ -15,8 +15,6 @@
 import logging
 from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
 
-from twisted.web.server import Request
-
 from synapse.api.constants import Membership
 from synapse.api.errors import SynapseError
 from synapse.http.server import HttpServer
@@ -97,7 +95,7 @@ class KnockRoomAliasServlet(RestServlet):
         return 200, {"room_id": room_id}
 
     def on_PUT(
-        self, request: Request, room_identifier: str, txn_id: str
+        self, request: SynapseRequest, room_identifier: str, txn_id: str
     ) -> Awaitable[Tuple[int, JsonDict]]:
         set_tag("txn_id", txn_id)
 
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 914fb3acf5..61375651bc 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -15,7 +15,9 @@
 """This module contains logic for storing HTTP PUT transactions. This is used
 to ensure idempotency when performing PUTs using the REST API."""
 import logging
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple
+
+from typing_extensions import ParamSpec
 
 from twisted.python.failure import Failure
 from twisted.web.server import Request
@@ -32,6 +34,9 @@ logger = logging.getLogger(__name__)
 CLEANUP_PERIOD_MS = 1000 * 60 * 30  # 30 mins
 
 
+P = ParamSpec("P")
+
+
 class HttpTransactionCache:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
@@ -65,9 +70,9 @@ class HttpTransactionCache:
     def fetch_or_execute_request(
         self,
         request: Request,
-        fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
-        *args: Any,
-        **kwargs: Any,
+        fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> Awaitable[Tuple[int, JsonDict]]:
         """A helper function for fetch_or_execute which extracts
         a transaction key from the given request.
@@ -82,9 +87,9 @@ class HttpTransactionCache:
     def fetch_or_execute(
         self,
         txn_key: str,
-        fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
-        *args: Any,
-        **kwargs: Any,
+        fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> Awaitable[Tuple[int, JsonDict]]:
         """Fetches the response for this transaction, or executes the given function
         to produce a response for this transaction.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 2255e55f6f..41f566b648 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -192,7 +192,7 @@ class LoggingDatabaseConnection:
 
 
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
-_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]
+_CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
 
 P = ParamSpec("P")
 R = TypeVar("R")
@@ -239,7 +239,9 @@ class LoggingTransaction:
         self.after_callbacks = after_callbacks
         self.exception_callbacks = exception_callbacks
 
-    def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
+    def call_after(
+        self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
+    ) -> None:
         """Call the given callback on the main twisted thread after the transaction has
         finished.
 
@@ -256,11 +258,12 @@ class LoggingTransaction:
         # LoggingTransaction isn't expecting there to be any callbacks; assert that
         # is not the case.
         assert self.after_callbacks is not None
-        self.after_callbacks.append((callback, args, kwargs))
+        # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
+        self.after_callbacks.append((callback, args, kwargs))  # type: ignore[arg-type]
 
     def call_on_exception(
-        self, callback: Callable[..., object], *args: Any, **kwargs: Any
-    ):
+        self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
+    ) -> None:
         """Call the given callback on the main twisted thread after the transaction has
         failed.
 
@@ -274,7 +277,8 @@ class LoggingTransaction:
         # LoggingTransaction isn't expecting there to be any callbacks; assert that
         # is not the case.
         assert self.exception_callbacks is not None
-        self.exception_callbacks.append((callback, args, kwargs))
+        # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
+        self.exception_callbacks.append((callback, args, kwargs))  # type: ignore[arg-type]
 
     def fetchone(self) -> Optional[Tuple]:
         return self.txn.fetchone()
@@ -549,9 +553,9 @@ class DatabasePool:
         desc: str,
         after_callbacks: List[_CallbackListEntry],
         exception_callbacks: List[_CallbackListEntry],
-        func: Callable[..., R],
-        *args: Any,
-        **kwargs: Any,
+        func: Callable[Concatenate[LoggingTransaction, P], R],
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> R:
         """Start a new database transaction with the given connection.
 
@@ -581,7 +585,10 @@ class DatabasePool:
         # will fail if we have to repeat the transaction.
         # For now, we just log an error, and hope that it works on the first attempt.
         # TODO: raise an exception.
-        for i, arg in enumerate(args):
+
+        # Type-ignore Mypy doesn't yet consider ParamSpec.args to be iterable; see
+        # https://github.com/python/mypy/pull/12668
+        for i, arg in enumerate(args):  # type: ignore[arg-type, var-annotated]
             if inspect.isgenerator(arg):
                 logger.error(
                     "Programming error: generator passed to new_transaction as "
@@ -589,7 +596,9 @@ class DatabasePool:
                     i,
                     func,
                 )
-        for name, val in kwargs.items():
+        # Type-ignore Mypy doesn't yet consider ParamSpec.args to be a mapping; see
+        # https://github.com/python/mypy/pull/12668
+        for name, val in kwargs.items():  # type: ignore[attr-defined]
             if inspect.isgenerator(val):
                 logger.error(
                     "Programming error: generator passed to new_transaction as "
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 9a6c2fd47a..ed29a0a5e2 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1648,8 +1648,12 @@ class PersistEventsStore:
         txn.call_after(prefill)
 
     def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
-        # Invalidate the caches for the redacted event, note that these caches
-        # are also cleared as part of event replication in _invalidate_caches_for_event.
+        """Invalidate the caches for the redacted event.
+
+        Note that these caches are also cleared as part of event replication in
+        _invalidate_caches_for_event.
+        """
+        assert event.redacts is not None
         txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
         txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
         txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index e27c5d298f..b91020117f 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -42,7 +42,7 @@ from typing import (
 )
 
 import attr
-from typing_extensions import AsyncContextManager, Literal
+from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec
 
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
@@ -237,9 +237,16 @@ async def concurrently_execute(
     )
 
 
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
 async def yieldable_gather_results(
-    func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any
-) -> List[T]:
+    func: Callable[Concatenate[T, P], Awaitable[R]],
+    iter: Iterable[T],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> List[R]:
     """Executes the function with each argument concurrently.
 
     Args:
@@ -255,7 +262,15 @@ async def yieldable_gather_results(
     try:
         return await make_deferred_yieldable(
             defer.gatherResults(
-                [run_in_background(func, item, *args, **kwargs) for item in iter],
+                # type-ignore: mypy reports two errors:
+                # error: Argument 1 to "run_in_background" has incompatible type
+                #     "Callable[[T, **P], Awaitable[R]]"; expected
+                #     "Callable[[T, **P], Awaitable[R]]"  [arg-type]
+                # error: Argument 2 to "run_in_background" has incompatible type
+                #     "T"; expected "[T, **P.args]"  [arg-type]
+                # The former looks like a mypy bug, and the latter looks like a
+                # false positive.
+                [run_in_background(func, item, *args, **kwargs) for item in iter],  # type: ignore[arg-type]
                 consumeErrors=True,
             )
         )
@@ -577,9 +592,6 @@ class ReadWriteLock:
         return _ctx_manager()
 
 
-R = TypeVar("R")
-
-
 def timeout_deferred(
     deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime
 ) -> "defer.Deferred[_T]":
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 91837655f8..b580bdd0de 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,7 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, Dict, List
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    TypeVar,
+    Union,
+)
+
+from typing_extensions import ParamSpec
 
 from twisted.internet import defer
 
@@ -75,7 +87,11 @@ class Distributor:
         run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
 
 
-class Signal:
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+class Signal(Generic[P]):
     """A Signal is a dispatch point that stores a list of callables as
     observers of it.
 
@@ -87,16 +103,16 @@ class Signal:
 
     def __init__(self, name: str):
         self.name: str = name
-        self.observers: List[Callable] = []
+        self.observers: List[Callable[P, Any]] = []
 
-    def observe(self, observer: Callable) -> None:
+    def observe(self, observer: Callable[P, Any]) -> None:
         """Adds a new callable to the observer list which will be invoked by
         the 'fire' method.
 
         Each observer callable may return a Deferred."""
         self.observers.append(observer)
 
-    def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]":
+    def fire(self, *args: P.args, **kwargs: P.kwargs) -> "defer.Deferred[List[Any]]":
         """Invokes every callable in the observer list, passing in the args and
         kwargs. Exceptions thrown by observers are logged but ignored. It is
         not an error to fire a signal with no observers.
@@ -104,7 +120,7 @@ class Signal:
         Returns a Deferred that will complete when all the observers have
         completed."""
 
-        async def do(observer: Callable[..., Any]) -> Any:
+        async def do(observer: Callable[P, Union[R, Awaitable[R]]]) -> Optional[R]:
             try:
                 return await maybe_awaitable(observer(*args, **kwargs))
             except Exception as e:
@@ -114,6 +130,7 @@ class Signal:
                     observer,
                     e,
                 )
+                return None
 
         deferreds = [run_in_background(do, o) for o in self.observers]
 
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 98ee49af6e..bc3b4938ea 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -15,10 +15,10 @@
 import logging
 from functools import wraps
 from types import TracebackType
-from typing import Any, Callable, Optional, Type, TypeVar, cast
+from typing import Awaitable, Callable, Optional, Type, TypeVar
 
 from prometheus_client import Counter
-from typing_extensions import Protocol
+from typing_extensions import Concatenate, ParamSpec, Protocol
 
 from synapse.logging.context import (
     ContextResourceUsage,
@@ -72,16 +72,21 @@ in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge(
 )
 
 
-T = TypeVar("T", bound=Callable[..., Any])
+P = ParamSpec("P")
+R = TypeVar("R")
 
 
 class HasClock(Protocol):
     clock: Clock
 
 
-def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
-    """
-    Used to decorate an async function with a `Measure` context manager.
+def measure_func(
+    name: Optional[str] = None,
+) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
+    """Decorate an async method with a `Measure` context manager.
+
+    The Measure is created using `self.clock`; it should only be used to decorate
+    methods in classes defining an instance-level `clock` attribute.
 
     Usage:
 
@@ -97,18 +102,24 @@ def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
 
     """
 
-    def wrapper(func: T) -> T:
+    def wrapper(
+        func: Callable[Concatenate[HasClock, P], Awaitable[R]]
+    ) -> Callable[P, Awaitable[R]]:
         block_name = func.__name__ if name is None else name
 
         @wraps(func)
-        async def measured_func(self: HasClock, *args: Any, **kwargs: Any) -> Any:
+        async def measured_func(self: HasClock, *args: P.args, **kwargs: P.kwargs) -> R:
             with Measure(self.clock, block_name):
                 r = await func(self, *args, **kwargs)
             return r
 
-        return cast(T, measured_func)
+        # There are some shenanigans here, because we're decorating a method but
+        # explicitly making use of the `self` parameter. The key thing here is that the
+        # return type within the return type for `measure_func` itself describes how the
+        # decorated function will be called.
+        return measured_func  # type: ignore[return-value]
 
-    return wrapper
+    return wrapper  # type: ignore[return-value]
 
 
 class Measure:
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index dace68666c..f97f98a057 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -16,6 +16,8 @@ import functools
 import sys
 from typing import Any, Callable, Generator, List, TypeVar, cast
 
+from typing_extensions import ParamSpec
+
 from twisted.internet import defer
 from twisted.internet.defer import Deferred
 from twisted.python.failure import Failure
@@ -25,6 +27,7 @@ _already_patched = False
 
 
 T = TypeVar("T")
+P = ParamSpec("P")
 
 
 def do_patch() -> None:
@@ -41,13 +44,13 @@ def do_patch() -> None:
         return
 
     def new_inline_callbacks(
-        f: Callable[..., Generator["Deferred[object]", object, T]]
-    ) -> Callable[..., "Deferred[T]"]:
+        f: Callable[P, Generator["Deferred[object]", object, T]]
+    ) -> Callable[P, "Deferred[T]"]:
         @functools.wraps(f)
-        def wrapped(*args: Any, **kwargs: Any) -> "Deferred[T]":
+        def wrapped(*args: P.args, **kwargs: P.kwargs) -> "Deferred[T]":
             start_context = current_context()
             changes: List[str] = []
-            orig: Callable[..., "Deferred[T]"] = orig_inline_callbacks(
+            orig: Callable[P, "Deferred[T]"] = orig_inline_callbacks(
                 _check_yield_points(f, changes)
             )
 
@@ -115,7 +118,7 @@ def do_patch() -> None:
 
 
 def _check_yield_points(
-    f: Callable[..., Generator["Deferred[object]", object, T]],
+    f: Callable[P, Generator["Deferred[object]", object, T]],
     changes: List[str],
 ) -> Callable:
     """Wraps a generator that is about to be passed to defer.inlineCallbacks
@@ -138,7 +141,7 @@ def _check_yield_points(
 
     @functools.wraps(f)
     def check_yield_points_inner(
-        *args: Any, **kwargs: Any
+        *args: P.args, **kwargs: P.kwargs
     ) -> Generator["Deferred[object]", object, T]:
         gen = f(*args, **kwargs)