summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/util/caches/cached_call.py27
1 files changed, 17 insertions, 10 deletions
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
index 891bee0b33..e58dd91eda 100644
--- a/synapse/util/caches/cached_call.py
+++ b/synapse/util/caches/cached_call.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import enum
 from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union
 
 from twisted.internet.defer import Deferred
@@ -22,6 +22,10 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 TV = TypeVar("TV")
 
 
+class _Sentinel(enum.Enum):
+    sentinel = object()
+
+
 class CachedCall(Generic[TV]):
     """A wrapper for asynchronous calls whose results should be shared
 
@@ -65,7 +69,7 @@ class CachedCall(Generic[TV]):
         """
         self._callable: Optional[Callable[[], Awaitable[TV]]] = f
         self._deferred: Optional[Deferred] = None
-        self._result: Union[None, Failure, TV] = None
+        self._result: Union[_Sentinel, TV, Failure] = _Sentinel.sentinel
 
     async def get(self) -> TV:
         """Kick off the call if necessary, and return the result"""
@@ -78,8 +82,9 @@ class CachedCall(Generic[TV]):
             self._callable = None
 
             # once the deferred completes, store the result. We cannot simply leave the
-            # result in the deferred, since if it's a Failure, GCing the deferred
-            # would then log a critical error about unhandled Failures.
+            # result in the deferred, since `awaiting` a deferred destroys its result.
+            # (Also, if it's a Failure, GCing the deferred would log a critical error
+            # about unhandled Failures)
             def got_result(r):
                 self._result = r
 
@@ -92,13 +97,15 @@ class CachedCall(Generic[TV]):
         #    and any eventual exception may not be reported.
 
         # we can now await the deferred, and once it completes, return the result.
-        await make_deferred_yieldable(self._deferred)
+        if isinstance(self._result, _Sentinel):
+            await make_deferred_yieldable(self._deferred)
+            assert not isinstance(self._result, _Sentinel)
+
+        if isinstance(self._result, Failure):
+            self._result.raiseException()
+            raise AssertionError("unexpected return from Failure.raiseException")
 
-        # I *think* this is the easiest way to correctly raise a Failure without having
-        # to gut-wrench into the implementation of Deferred.
-        d = Deferred()
-        d.callback(self._result)
-        return await d
+        return self._result
 
 
 class RetryOnExceptionCachedCall(Generic[TV]):