summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--scripts-dev/mypy_synapse_plugin.py55
1 files changed, 18 insertions, 37 deletions
diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py
index 76a81c8612..62f1703728 100644
--- a/scripts-dev/mypy_synapse_plugin.py
+++ b/scripts-dev/mypy_synapse_plugin.py
@@ -46,13 +46,13 @@ class SynapsePlugin(Plugin):
     def get_method_signature_hook(
         self, fullname: str
     ) -> Optional[Callable[[MethodSigContext], CallableType]]:
-        if fullname.startswith(
-            (
-                "synapse.util.caches.descriptors.CachedFunction.__call__",
-                "synapse.util.caches.descriptors._LruCachedFunction.__call__",
-            )
-        ):
-            return cached_function_method_signature
+        # if fullname.startswith(
+        #     (
+        #         "synapse.util.caches.descriptors.CachedFunction.__call__",
+        #         "synapse.util.caches.descriptors._LruCachedFunction.__call__",
+        #     )
+        # ):
+        #     return cached_function_method_signature
 
         if fullname in (
             "synapse.util.caches.descriptors._CachedFunctionDescriptor.__call__",
@@ -67,7 +67,12 @@ class SynapsePlugin(Plugin):
     ) -> Optional[Callable[[AttributeContext], mypy.types.Type]]:
         # Anything in synapse could be wrapped with the cached decorator, but
         # we know that anything else is *not*.
-        if fullname.startswith("synapse."):
+        if fullname.startswith(
+            (
+                "synapse.util.caches.descriptors.CachedFunction",
+                "synapse.util.caches.descriptors._LruCachedFunction",
+            )
+        ):
             return cached_function_method_attribute
         return None
 
@@ -99,35 +104,11 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
 
 def cached_function_method_attribute(ctx: AttributeContext) -> mypy.types.Type:
     if isinstance(ctx.default_attr_type, Instance):
-        if (
-            ctx.default_attr_type.type.fullname
-            == "synapse.util.caches.descriptors.CachedFunction"
-        ):
-            if getattr(ctx.context, "name") == "did_forget":
-                breakpoint()
-
-            # Unwrap the wrapped function.
-            wrapped_signature = ctx.default_attr_type.args[0]
-            assert isinstance(wrapped_signature, CallableType)
-
-            # 1. Mark this as a bound function signature.
-            signature: CallableType = bind_self(wrapped_signature)
-
-            # 4. Ensure the return type is a Deferred.
-            ret_arg = _get_true_return_type(signature)
-
-            # This should be able to use api.named_generic_type, but that doesn't seem
-            # to find the correct symbol for anything more than 1 module deep.
-            #
-            # modules is not part of CheckerPluginInterface. The following is a combination
-            # of TypeChecker.named_generic_type and TypeChecker.lookup_typeinfo.
-            sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred")  # type: ignore[attr-defined]
-            ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)])
-
-            return ctx.default_attr_type.copy_modified(
-                args=[signature.copy_modified(ret_type=ret_type)]
-            )
-
+        wrapped_callable = ctx.default_attr_type.args[0]
+        assert isinstance(wrapped_callable, CallableType)
+        return ctx.default_attr_type.copy_modified(
+            args=[_unwrap_cached_decoratored_function(wrapped_callable, ctx.api)]
+        )
     return ctx.default_attr_type