diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py
index 2c377533c0..8058e9c993 100644
--- a/scripts-dev/mypy_synapse_plugin.py
+++ b/scripts-dev/mypy_synapse_plugin.py
@@ -18,10 +18,11 @@ can crop up, e.g the cache descriptors.
from typing import Callable, Optional, Type
+from mypy.erasetype import remove_instance_last_known_values
from mypy.nodes import ARG_NAMED_OPT
from mypy.plugin import MethodSigContext, Plugin
from mypy.typeops import bind_self
-from mypy.types import CallableType, NoneType, UnionType
+from mypy.types import CallableType, Instance, NoneType, UnionType
class SynapsePlugin(Plugin):
@@ -92,10 +93,41 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
arg_names.append("on_invalidate")
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.
+ # Finally we ensure the return type is a Deferred.
+ if (
+ isinstance(signature.ret_type, Instance)
+ and signature.ret_type.type.fullname == "twisted.internet.defer.Deferred"
+ ):
+ # If it is already a Deferred, nothing to do.
+ ret_type = signature.ret_type
+ else:
+ ret_arg = None
+ if isinstance(signature.ret_type, Instance):
+ # If a coroutine, wrap the coroutine's return type in a Deferred.
+ if signature.ret_type.type.fullname == "typing.Coroutine":
+ ret_arg = signature.ret_type.args[2]
+
+ # If an awaitable, wrap the awaitable's final value in a Deferred.
+ elif signature.ret_type.type.fullname == "typing.Awaitable":
+ ret_arg = signature.ret_type.args[0]
+
+ # Otherwise, wrap the return value in a Deferred.
+ if ret_arg is None:
+ ret_arg = signature.ret_type
+
+ # This should be able to use ctx.api.named_generic_type, but that doesn't seem
+ # to find the correct symbol for anything more than 1 module deep.
+ #
+ # modules is not part of CheckerPluginInterface. The following is a combination
+ # of TypeChecker.named_generic_type and TypeChecker.lookup_typeinfo.
+ sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred") # type: ignore[attr-defined]
+ ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)])
+
signature = signature.copy_modified(
arg_types=arg_types,
arg_names=arg_names,
arg_kinds=arg_kinds,
+ ret_type=ret_type,
)
return signature
|