From 208e1d3eb345dca12e25696e30cee7e788b65ae2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 3 Sep 2020 15:38:32 +0100 Subject: Fix typing for `@cached` wrapped functions (#8240) This requires adding a mypy plugin to fiddle with the type signatures a bit. --- synapse/handlers/federation.py | 10 ++++----- synapse/util/caches/descriptors.py | 42 +++++++++++++++++++++++++------------- 2 files changed, 33 insertions(+), 19 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bd8efbb768..310c7f7138 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -440,11 +440,11 @@ class FederationHandler(BaseHandler): if not prevs - seen: return - latest = await self.store.get_latest_event_ids_in_room(room_id) + latest_list = await self.store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us - latest = set(latest) + latest = set(latest_list) latest |= seen logger.info( @@ -781,7 +781,7 @@ class FederationHandler(BaseHandler): # keys across all devices. current_keys = [ key - for device in cached_devices + for device in cached_devices.values() for key in device.get("keys", {}).get("keys", {}).values() ] @@ -2119,8 +2119,8 @@ class FederationHandler(BaseHandler): if backfilled or event.internal_metadata.is_outlier(): return - extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id) - extrem_ids = set(extrem_ids) + extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id) + extrem_ids = set(extrem_ids_list) prev_event_ids = set(event.prev_event_ids()) if extrem_ids == prev_event_ids: diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 49d9fddcf0..825810eb16 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -18,11 +18,10 @@ import functools import inspect import logging import threading -from typing import Any, Tuple, Union, cast +from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast from weakref import WeakValueDictionary from prometheus_client import Gauge -from typing_extensions import Protocol from twisted.internet import defer @@ -38,8 +37,10 @@ logger = logging.getLogger(__name__) CacheKey = Union[Tuple, Any] +F = TypeVar("F", bound=Callable[..., Any]) -class _CachedFunction(Protocol): + +class _CachedFunction(Generic[F]): invalidate = None # type: Any invalidate_all = None # type: Any invalidate_many = None # type: Any @@ -47,8 +48,11 @@ class _CachedFunction(Protocol): cache = None # type: Any num_args = None # type: Any - def __name__(self): - ... + __name__ = None # type: str + + # Note: This function signature is actually fiddled with by the synapse mypy + # plugin to a) make it a bound method, and b) remove any `cache_context` arg. + __call__ = None # type: F cache_pending_metric = Gauge( @@ -123,7 +127,7 @@ class Cache(object): self.name = name self.keylen = keylen - self.thread = None + self.thread = None # type: Optional[threading.Thread] self.metrics = register_cache( "cache", name, @@ -662,9 +666,13 @@ class _CacheContext: def cached( - max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False -): - return lambda orig: CacheDescriptor( + max_entries: int = 1000, + num_args: Optional[int] = None, + tree: bool = False, + cache_context: bool = False, + iterable: bool = False, +) -> Callable[[F], _CachedFunction[F]]: + func = lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, @@ -673,8 +681,12 @@ def cached( iterable=iterable, ) + return cast(Callable[[F], _CachedFunction[F]], func) -def cachedList(cached_method_name, list_name, num_args=None): + +def cachedList( + cached_method_name: str, list_name: str, num_args: Optional[int] = None +) -> Callable[[F], _CachedFunction[F]]: """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -684,11 +696,11 @@ def cachedList(cached_method_name, list_name, num_args=None): cache. Args: - cached_method_name (str): The name of the single-item lookup method. + cached_method_name: The name of the single-item lookup method. This is only used to find the cache to use. - list_name (str): The name of the argument that is the list to use to + list_name: The name of the argument that is the list to use to do batch lookups in the cache. - num_args (int): Number of arguments to use as the key in the cache + num_args: Number of arguments to use as the key in the cache (including list_name). Defaults to all named parameters. Example: @@ -702,9 +714,11 @@ def cachedList(cached_method_name, list_name, num_args=None): def batch_do_something(self, first_arg, second_args): ... """ - return lambda orig: CacheListDescriptor( + func = lambda orig: CacheListDescriptor( orig, cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, ) + + return cast(Callable[[F], _CachedFunction[F]], func) -- cgit 1.4.1