summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-05-24 08:59:31 -0400
committerGitHub <noreply@github.com>2023-05-24 12:59:31 +0000
commit1f55c04cbca6dc56085896dd980defa26ffe3b5b (patch)
treedcc51ffeee2c83f78379f58165ef8f3f83c915f0
parentFix `@trace` not wrapping some state methods that return coroutines correctly... (diff)
downloadsynapse-1f55c04cbca6dc56085896dd980defa26ffe3b5b.tar.xz
Improve type hints for cached decorator. (#15658)
The cached decorators always return a Deferred, which was not
properly propagated. It was close enough when wrapping coroutines,
but failed if a bare function was wrapped.
-rw-r--r--changelog.d/15658.misc1
-rw-r--r--scripts-dev/mypy_synapse_plugin.py34
-rw-r--r--synapse/storage/databases/main/roommember.py2
-rw-r--r--synapse/util/caches/descriptors.py6
-rw-r--r--tests/appservice/test_appservice.py82
-rw-r--r--tests/storage/test_transactions.py11
6 files changed, 73 insertions, 63 deletions
diff --git a/changelog.d/15658.misc b/changelog.d/15658.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/15658.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py
index 2c377533c0..8058e9c993 100644
--- a/scripts-dev/mypy_synapse_plugin.py
+++ b/scripts-dev/mypy_synapse_plugin.py
@@ -18,10 +18,11 @@ can crop up, e.g the cache descriptors.
 
 from typing import Callable, Optional, Type
 
+from mypy.erasetype import remove_instance_last_known_values
 from mypy.nodes import ARG_NAMED_OPT
 from mypy.plugin import MethodSigContext, Plugin
 from mypy.typeops import bind_self
-from mypy.types import CallableType, NoneType, UnionType
+from mypy.types import CallableType, Instance, NoneType, UnionType
 
 
 class SynapsePlugin(Plugin):
@@ -92,10 +93,41 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
     arg_names.append("on_invalidate")
     arg_kinds.append(ARG_NAMED_OPT)  # Arg is an optional kwarg.
 
+    # Finally we ensure the return type is a Deferred.
+    if (
+        isinstance(signature.ret_type, Instance)
+        and signature.ret_type.type.fullname == "twisted.internet.defer.Deferred"
+    ):
+        # If it is already a Deferred, nothing to do.
+        ret_type = signature.ret_type
+    else:
+        ret_arg = None
+        if isinstance(signature.ret_type, Instance):
+            # If a coroutine, wrap the coroutine's return type in a Deferred.
+            if signature.ret_type.type.fullname == "typing.Coroutine":
+                ret_arg = signature.ret_type.args[2]
+
+            # If an awaitable, wrap the awaitable's final value in a Deferred.
+            elif signature.ret_type.type.fullname == "typing.Awaitable":
+                ret_arg = signature.ret_type.args[0]
+
+        # Otherwise, wrap the return value in a Deferred.
+        if ret_arg is None:
+            ret_arg = signature.ret_type
+
+        # This should be able to use ctx.api.named_generic_type, but that doesn't seem
+        # to find the correct symbol for anything more than 1 module deep.
+        #
+        # modules is not part of CheckerPluginInterface. The following is a combination
+        # of TypeChecker.named_generic_type and TypeChecker.lookup_typeinfo.
+        sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred")  # type: ignore[attr-defined]
+        ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)])
+
     signature = signature.copy_modified(
         arg_types=arg_types,
         arg_names=arg_names,
         arg_kinds=arg_kinds,
+        ret_type=ret_type,
     )
 
     return signature
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e068f27a10..ae9c201b87 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -1099,7 +1099,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
         # `get_joined_hosts` is called with the "current" state group for the
         # room, and so consecutive calls will be for consecutive state groups
         # which point to the previous state group.
-        cache = await self._get_joined_hosts_cache(room_id)  # type: ignore[misc]
+        cache = await self._get_joined_hosts_cache(room_id)
 
         # If the state group in the cache matches, we already have the data we need.
         if state_entry.state_group == cache.state_group:
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 81df71a0c5..8514a75a1c 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -220,7 +220,9 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         self.iterable = iterable
         self.prune_unread_entries = prune_unread_entries
 
-    def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
+    def __get__(
+        self, obj: Optional[Any], owner: Optional[Type]
+    ) -> Callable[..., "defer.Deferred[Any]"]:
         cache: DeferredCache[CacheKey, Any] = DeferredCache(
             name=self.name,
             max_entries=self.max_entries,
@@ -232,7 +234,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         get_cache_key = self.cache_key_builder
 
         @functools.wraps(self.orig)
-        def _wrapped(*args: Any, **kwargs: Any) -> Any:
+        def _wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Any]":
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index dee976356f..66753c60c4 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import re
-from typing import Generator
+from typing import Any, Generator
 from unittest.mock import Mock
 
 from twisted.internet import defer
@@ -49,15 +49,13 @@ class ApplicationServiceTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_regex_user_id_prefix_match(
         self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         self.event.sender = "@irc_foobar:matrix.org"
         self.assertTrue(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
@@ -65,15 +63,13 @@ class ApplicationServiceTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_regex_user_id_prefix_no_match(
         self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         self.event.sender = "@someone_else:matrix.org"
         self.assertFalse(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
@@ -81,17 +77,15 @@ class ApplicationServiceTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_regex_room_member_is_checked(
         self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         self.event.sender = "@someone_else:matrix.org"
         self.event.type = "m.room.member"
         self.event.state_key = "@irc_foobar:matrix.org"
         self.assertTrue(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
@@ -99,17 +93,15 @@ class ApplicationServiceTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_regex_room_id_match(
         self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_ROOMS].append(
             _regex("!some_prefix.*some_suffix:matrix.org")
         )
         self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
         self.assertTrue(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
@@ -117,25 +109,21 @@ class ApplicationServiceTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_regex_room_id_no_match(
         self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_ROOMS].append(
             _regex("!some_prefix.*some_suffix:matrix.org")
         )
         self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
         self.assertFalse(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
 
     @defer.inlineCallbacks
-    def test_regex_alias_match(
-        self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    def test_regex_alias_match(self) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_ALIASES].append(
             _regex("#irc_.*:matrix.org")
         )
@@ -145,10 +133,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.store.get_local_users_in_room = simple_async_mock([])
         self.assertTrue(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
@@ -192,7 +178,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_regex_alias_no_match(
         self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_ALIASES].append(
             _regex("#irc_.*:matrix.org")
         )
@@ -213,7 +199,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_regex_multiple_matches(
         self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_ALIASES].append(
             _regex("#irc_.*:matrix.org")
         )
@@ -223,18 +209,14 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.store.get_local_users_in_room = simple_async_mock([])
         self.assertTrue(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
 
     @defer.inlineCallbacks
-    def test_interested_in_self(
-        self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    def test_interested_in_self(self) -> Generator["defer.Deferred[Any]", object, None]:
         # make sure invites get through
         self.service.sender = "@appservice:name"
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
@@ -243,18 +225,14 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.event.state_key = self.service.sender
         self.assertTrue(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
 
     @defer.inlineCallbacks
-    def test_member_list_match(
-        self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         # Note that @irc_fo:here is the AS user.
         self.store.get_local_users_in_room = simple_async_mock(
@@ -265,10 +243,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.event.sender = "@xmpp_foobar:matrix.org"
         self.assertTrue(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py
index db9ee9955e..2fab84a529 100644
--- a/tests/storage/test_transactions.py
+++ b/tests/storage/test_transactions.py
@@ -33,15 +33,14 @@ class TransactionStoreTestCase(HomeserverTestCase):
         destination retries, as well as testing tht we can set and get
         correctly.
         """
-        d = self.store.get_destination_retry_timings("example.com")
-        r = self.get_success(d)
+        r = self.get_success(self.store.get_destination_retry_timings("example.com"))
         self.assertIsNone(r)
 
-        d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
-        self.get_success(d)
+        self.get_success(
+            self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
+        )
 
-        d = self.store.get_destination_retry_timings("example.com")
-        r = self.get_success(d)
+        r = self.get_success(self.store.get_destination_retry_timings("example.com"))
 
         self.assertEqual(
             DestinationRetryTimings(