summary refs log tree commit diff
path: root/scripts-dev/mypy_synapse_plugin.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-05-24 08:59:31 -0400
committerGitHub <noreply@github.com>2023-05-24 12:59:31 +0000
commit1f55c04cbca6dc56085896dd980defa26ffe3b5b (patch)
treedcc51ffeee2c83f78379f58165ef8f3f83c915f0 /scripts-dev/mypy_synapse_plugin.py
parentFix `@trace` not wrapping some state methods that return coroutines correctly... (diff)
downloadsynapse-1f55c04cbca6dc56085896dd980defa26ffe3b5b.tar.xz
Improve type hints for cached decorator. (#15658)
The cached decorators always return a Deferred, which was not
properly propagated. It was close enough when wrapping coroutines,
but failed if a bare function was wrapped.
Diffstat (limited to 'scripts-dev/mypy_synapse_plugin.py')
-rw-r--r--scripts-dev/mypy_synapse_plugin.py34
1 files changed, 33 insertions, 1 deletions
diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py
index 2c377533c0..8058e9c993 100644
--- a/scripts-dev/mypy_synapse_plugin.py
+++ b/scripts-dev/mypy_synapse_plugin.py
@@ -18,10 +18,11 @@ can crop up, e.g the cache descriptors.
 
 from typing import Callable, Optional, Type
 
+from mypy.erasetype import remove_instance_last_known_values
 from mypy.nodes import ARG_NAMED_OPT
 from mypy.plugin import MethodSigContext, Plugin
 from mypy.typeops import bind_self
-from mypy.types import CallableType, NoneType, UnionType
+from mypy.types import CallableType, Instance, NoneType, UnionType
 
 
 class SynapsePlugin(Plugin):
@@ -92,10 +93,41 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
     arg_names.append("on_invalidate")
     arg_kinds.append(ARG_NAMED_OPT)  # Arg is an optional kwarg.
 
+    # Finally we ensure the return type is a Deferred.
+    if (
+        isinstance(signature.ret_type, Instance)
+        and signature.ret_type.type.fullname == "twisted.internet.defer.Deferred"
+    ):
+        # If it is already a Deferred, nothing to do.
+        ret_type = signature.ret_type
+    else:
+        ret_arg = None
+        if isinstance(signature.ret_type, Instance):
+            # If a coroutine, wrap the coroutine's return type in a Deferred.
+            if signature.ret_type.type.fullname == "typing.Coroutine":
+                ret_arg = signature.ret_type.args[2]
+
+            # If an awaitable, wrap the awaitable's final value in a Deferred.
+            elif signature.ret_type.type.fullname == "typing.Awaitable":
+                ret_arg = signature.ret_type.args[0]
+
+        # Otherwise, wrap the return value in a Deferred.
+        if ret_arg is None:
+            ret_arg = signature.ret_type
+
+        # This should be able to use ctx.api.named_generic_type, but that doesn't seem
+        # to find the correct symbol for anything more than 1 module deep.
+        #
+        # modules is not part of CheckerPluginInterface. The following is a combination
+        # of TypeChecker.named_generic_type and TypeChecker.lookup_typeinfo.
+        sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred")  # type: ignore[attr-defined]
+        ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)])
+
     signature = signature.copy_modified(
         arg_types=arg_types,
         arg_names=arg_names,
         arg_kinds=arg_kinds,
+        ret_type=ret_type,
     )
 
     return signature