summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12089.bugfix1
-rw-r--r--synapse/util/caches/descriptors.py64
-rw-r--r--tests/util/caches/test_descriptors.py10
3 files changed, 38 insertions, 37 deletions
diff --git a/changelog.d/12089.bugfix b/changelog.d/12089.bugfix
new file mode 100644
index 0000000000..27172c4828
--- /dev/null
+++ b/changelog.d/12089.bugfix
@@ -0,0 +1 @@
+Fix occasional 'Unhandled error in Deferred' error message.
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index df4fb156c2..1cdead02f1 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,6 +18,7 @@ import inspect
 import logging
 from typing import (
     Any,
+    Awaitable,
     Callable,
     Dict,
     Generic,
@@ -346,15 +347,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
     """Wraps an existing cache to support bulk fetching of keys.
 
     Given an iterable of keys it looks in the cache to find any hits, then passes
-    the tuple of missing keys to the wrapped function.
+    the set of missing keys to the wrapped function.
 
-    Once wrapped, the function returns a Deferred which resolves to the list
-    of results.
+    Once wrapped, the function returns a Deferred which resolves to a Dict mapping from
+    input key to output value.
     """
 
     def __init__(
         self,
-        orig: Callable[..., Any],
+        orig: Callable[..., Awaitable[Dict]],
         cached_method_name: str,
         list_name: str,
         num_args: Optional[int] = None,
@@ -385,13 +386,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
 
     def __get__(
         self, obj: Optional[Any], objtype: Optional[Type] = None
-    ) -> Callable[..., Any]:
+    ) -> Callable[..., "defer.Deferred[Dict[Hashable, Any]]"]:
         cached_method = getattr(obj, self.cached_method_name)
         cache: DeferredCache[CacheKey, Any] = cached_method.cache
         num_args = cached_method.num_args
 
         @functools.wraps(self.orig)
-        def wrapped(*args: Any, **kwargs: Any) -> Any:
+        def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]":
             # If we're passed a cache_context then we'll want to call its
             # invalidate() whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -444,39 +445,38 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                     deferred: "defer.Deferred[Any]" = defer.Deferred()
                     deferreds_map[arg] = deferred
                     key = arg_to_cache_key(arg)
-                    cache.set(key, deferred, callback=invalidate_callback)
+                    cached_defers.append(
+                        cache.set(key, deferred, callback=invalidate_callback)
+                    )
 
                 def complete_all(res: Dict[Hashable, Any]) -> None:
-                    # the wrapped function has completed. It returns a
-                    # a dict. We can now resolve the observable deferreds in
-                    # the cache and update our own result map.
-                    for e in missing:
+                    # the wrapped function has completed. It returns a dict.
+                    # We can now update our own result map, and then resolve the
+                    # observable deferreds in the cache.
+                    for e, d1 in deferreds_map.items():
                         val = res.get(e, None)
-                        deferreds_map[e].callback(val)
+                        # make sure we update the results map before running the
+                        # deferreds, because as soon as we run the last deferred, the
+                        # gatherResults() below will complete and return the result
+                        # dict to our caller.
                         results[e] = val
+                        d1.callback(val)
 
-                def errback(f: Failure) -> Failure:
-                    # the wrapped function has failed. Invalidate any cache
-                    # entries we're supposed to be populating, and fail
-                    # their deferreds.
-                    for e in missing:
-                        key = arg_to_cache_key(e)
-                        cache.invalidate(key)
-                        deferreds_map[e].errback(f)
-
-                    # return the failure, to propagate to our caller.
-                    return f
+                def errback_all(f: Failure) -> None:
+                    # the wrapped function has failed. Propagate the failure into
+                    # the cache, which will invalidate the entry, and cause the
+                    # relevant cached_deferreds to fail, which will propagate the
+                    # failure to our caller.
+                    for d1 in deferreds_map.values():
+                        d1.errback(f)
 
                 args_to_call = dict(arg_dict)
-                # copy the missing set before sending it to the callee, to guard against
-                # modification.
-                args_to_call[self.list_name] = tuple(missing)
-
-                cached_defers.append(
-                    defer.maybeDeferred(
-                        preserve_fn(self.orig), **args_to_call
-                    ).addCallbacks(complete_all, errback)
-                )
+                args_to_call[self.list_name] = missing
+
+                # dispatch the call, and attach the two handlers
+                defer.maybeDeferred(
+                    preserve_fn(self.orig), **args_to_call
+                ).addCallbacks(complete_all, errback_all)
 
             if cached_defers:
                 d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index b92d3f0c1b..19741ffcda 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -673,14 +673,14 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             self.assertEqual(current_context(), SENTINEL_CONTEXT)
             r = yield d1
             self.assertEqual(current_context(), c1)
-            obj.mock.assert_called_once_with((10, 20), 2)
+            obj.mock.assert_called_once_with({10, 20}, 2)
             self.assertEqual(r, {10: "fish", 20: "chips"})
             obj.mock.reset_mock()
 
             # a call with different params should call the mock again
             obj.mock.return_value = {30: "peas"}
             r = yield obj.list_fn([20, 30], 2)
-            obj.mock.assert_called_once_with((30,), 2)
+            obj.mock.assert_called_once_with({30}, 2)
             self.assertEqual(r, {20: "chips", 30: "peas"})
             obj.mock.reset_mock()
 
@@ -701,7 +701,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             obj.mock.return_value = {40: "gravy"}
             iterable = (x for x in [10, 40, 40])
             r = yield obj.list_fn(iterable, 2)
-            obj.mock.assert_called_once_with((40,), 2)
+            obj.mock.assert_called_once_with({40}, 2)
             self.assertEqual(r, {10: "fish", 40: "gravy"})
 
     def test_concurrent_lookups(self):
@@ -729,7 +729,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         d3 = obj.list_fn([10])
 
         # the mock should have been called exactly once
-        obj.mock.assert_called_once_with((10,))
+        obj.mock.assert_called_once_with({10})
         obj.mock.reset_mock()
 
         # ... and none of the calls should yet be complete
@@ -771,7 +771,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         # cache miss
         obj.mock.return_value = {10: "fish", 20: "chips"}
         r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
-        obj.mock.assert_called_once_with((10, 20), 2)
+        obj.mock.assert_called_once_with({10, 20}, 2)
         self.assertEqual(r1, {10: "fish", 20: "chips"})
         obj.mock.reset_mock()