summary refs log tree commit diff
path: root/scripts-dev
diff options
context:
space:
mode:
Diffstat (limited to 'scripts-dev')
-rw-r--r--scripts-dev/mypy_synapse_plugin.py34
1 files changed, 33 insertions, 1 deletions
diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py
index 2c377533c0..8058e9c993 100644
--- a/scripts-dev/mypy_synapse_plugin.py
+++ b/scripts-dev/mypy_synapse_plugin.py
@@ -18,10 +18,11 @@ can crop up, e.g the cache descriptors.
 
 from typing import Callable, Optional, Type
 
+from mypy.erasetype import remove_instance_last_known_values
 from mypy.nodes import ARG_NAMED_OPT
 from mypy.plugin import MethodSigContext, Plugin
 from mypy.typeops import bind_self
-from mypy.types import CallableType, NoneType, UnionType
+from mypy.types import CallableType, Instance, NoneType, UnionType
 
 
 class SynapsePlugin(Plugin):
@@ -92,10 +93,41 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
     arg_names.append("on_invalidate")
     arg_kinds.append(ARG_NAMED_OPT)  # Arg is an optional kwarg.
 
+    # Finally we ensure the return type is a Deferred.
+    if (
+        isinstance(signature.ret_type, Instance)
+        and signature.ret_type.type.fullname == "twisted.internet.defer.Deferred"
+    ):
+        # If it is already a Deferred, nothing to do.
+        ret_type = signature.ret_type
+    else:
+        ret_arg = None
+        if isinstance(signature.ret_type, Instance):
+            # If a coroutine, wrap the coroutine's return type in a Deferred.
+            if signature.ret_type.type.fullname == "typing.Coroutine":
+                ret_arg = signature.ret_type.args[2]
+
+            # If an awaitable, wrap the awaitable's final value in a Deferred.
+            elif signature.ret_type.type.fullname == "typing.Awaitable":
+                ret_arg = signature.ret_type.args[0]
+
+        # Otherwise, wrap the return value in a Deferred.
+        if ret_arg is None:
+            ret_arg = signature.ret_type
+
+        # This should be able to use ctx.api.named_generic_type, but that doesn't seem
+        # to find the correct symbol for anything more than 1 module deep.
+        #
+        # modules is not part of CheckerPluginInterface. The following is a combination
+        # of TypeChecker.named_generic_type and TypeChecker.lookup_typeinfo.
+        sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred")  # type: ignore[attr-defined]
+        ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)])
+
     signature = signature.copy_modified(
         arg_types=arg_types,
         arg_names=arg_names,
         arg_kinds=arg_kinds,
+        ret_type=ret_type,
     )
 
     return signature