diff options
-rw-r--r-- | scripts-dev/mypy_synapse_plugin.py | 55 |
1 files changed, 18 insertions, 37 deletions
diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 76a81c8612..62f1703728 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -46,13 +46,13 @@ class SynapsePlugin(Plugin): def get_method_signature_hook( self, fullname: str ) -> Optional[Callable[[MethodSigContext], CallableType]]: - if fullname.startswith( - ( - "synapse.util.caches.descriptors.CachedFunction.__call__", - "synapse.util.caches.descriptors._LruCachedFunction.__call__", - ) - ): - return cached_function_method_signature + # if fullname.startswith( + # ( + # "synapse.util.caches.descriptors.CachedFunction.__call__", + # "synapse.util.caches.descriptors._LruCachedFunction.__call__", + # ) + # ): + # return cached_function_method_signature if fullname in ( "synapse.util.caches.descriptors._CachedFunctionDescriptor.__call__", @@ -67,7 +67,12 @@ class SynapsePlugin(Plugin): ) -> Optional[Callable[[AttributeContext], mypy.types.Type]]: # Anything in synapse could be wrapped with the cached decorator, but # we know that anything else is *not*. - if fullname.startswith("synapse."): + if fullname.startswith( + ( + "synapse.util.caches.descriptors.CachedFunction", + "synapse.util.caches.descriptors._LruCachedFunction", + ) + ): return cached_function_method_attribute return None @@ -99,35 +104,11 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: def cached_function_method_attribute(ctx: AttributeContext) -> mypy.types.Type: if isinstance(ctx.default_attr_type, Instance): - if ( - ctx.default_attr_type.type.fullname - == "synapse.util.caches.descriptors.CachedFunction" - ): - if getattr(ctx.context, "name") == "did_forget": - breakpoint() - - # Unwrap the wrapped function. - wrapped_signature = ctx.default_attr_type.args[0] - assert isinstance(wrapped_signature, CallableType) - - # 1. Mark this as a bound function signature. - signature: CallableType = bind_self(wrapped_signature) - - # 4. Ensure the return type is a Deferred. - ret_arg = _get_true_return_type(signature) - - # This should be able to use api.named_generic_type, but that doesn't seem - # to find the correct symbol for anything more than 1 module deep. - # - # modules is not part of CheckerPluginInterface. The following is a combination - # of TypeChecker.named_generic_type and TypeChecker.lookup_typeinfo. - sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred") # type: ignore[attr-defined] - ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)]) - - return ctx.default_attr_type.copy_modified( - args=[signature.copy_modified(ret_type=ret_type)] - ) - + wrapped_callable = ctx.default_attr_type.args[0] + assert isinstance(wrapped_callable, CallableType) + return ctx.default_attr_type.copy_modified( + args=[_unwrap_cached_decoratored_function(wrapped_callable, ctx.api)] + ) return ctx.default_attr_type |