diff --git a/tests/server.py b/tests/server.py
index 0434d1a9c0..f719622fa1 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -100,6 +100,7 @@ logger = logging.getLogger(__name__)
# the type of thing that can be passed into `make_request` in the headers list
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
P = ParamSpec("P")
+R = TypeVar("R")
class TimedOutException(Exception):
@@ -559,7 +560,7 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
for database in server.get_datastores().databases:
pool = database._db_pool
- def runWithConnection(func, *args, **kwargs):
+ def runWithConnection(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> "Deferred[R]":
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
@@ -569,7 +570,7 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
**kwargs,
)
- def runInteraction(interaction, *args, **kwargs):
+ def runInteraction(interaction: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> "Deferred[R]":
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
|