summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-05-24 08:59:31 -0400
committerGitHub <noreply@github.com>2023-05-24 12:59:31 +0000
commit1f55c04cbca6dc56085896dd980defa26ffe3b5b (patch)
treedcc51ffeee2c83f78379f58165ef8f3f83c915f0 /synapse
parentFix `@trace` not wrapping some state methods that return coroutines correctly... (diff)
downloadsynapse-1f55c04cbca6dc56085896dd980defa26ffe3b5b.tar.xz
Improve type hints for cached decorator. (#15658)
The cached decorators always return a Deferred, which was not
properly propagated. It was close enough when wrapping coroutines,
but failed if a bare function was wrapped.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/databases/main/roommember.py2
-rw-r--r--synapse/util/caches/descriptors.py6
2 files changed, 5 insertions, 3 deletions
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e068f27a10..ae9c201b87 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -1099,7 +1099,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
         # `get_joined_hosts` is called with the "current" state group for the
         # room, and so consecutive calls will be for consecutive state groups
         # which point to the previous state group.
-        cache = await self._get_joined_hosts_cache(room_id)  # type: ignore[misc]
+        cache = await self._get_joined_hosts_cache(room_id)
 
         # If the state group in the cache matches, we already have the data we need.
         if state_entry.state_group == cache.state_group:
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 81df71a0c5..8514a75a1c 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -220,7 +220,9 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         self.iterable = iterable
         self.prune_unread_entries = prune_unread_entries
 
-    def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
+    def __get__(
+        self, obj: Optional[Any], owner: Optional[Type]
+    ) -> Callable[..., "defer.Deferred[Any]"]:
         cache: DeferredCache[CacheKey, Any] = DeferredCache(
             name=self.name,
             max_entries=self.max_entries,
@@ -232,7 +234,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         get_cache_key = self.cache_key_builder
 
         @functools.wraps(self.orig)
-        def _wrapped(*args: Any, **kwargs: Any) -> Any:
+        def _wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Any]":
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)