diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
index 891bee0b33..e58dd91eda 100644
--- a/synapse/util/caches/cached_call.py
+++ b/synapse/util/caches/cached_call.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import enum
from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union
from twisted.internet.defer import Deferred
@@ -22,6 +22,10 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
TV = TypeVar("TV")
+class _Sentinel(enum.Enum):
+ sentinel = object()
+
+
class CachedCall(Generic[TV]):
"""A wrapper for asynchronous calls whose results should be shared
@@ -65,7 +69,7 @@ class CachedCall(Generic[TV]):
"""
self._callable: Optional[Callable[[], Awaitable[TV]]] = f
self._deferred: Optional[Deferred] = None
- self._result: Union[None, Failure, TV] = None
+ self._result: Union[_Sentinel, TV, Failure] = _Sentinel.sentinel
async def get(self) -> TV:
"""Kick off the call if necessary, and return the result"""
@@ -78,8 +82,9 @@ class CachedCall(Generic[TV]):
self._callable = None
# once the deferred completes, store the result. We cannot simply leave the
- # result in the deferred, since if it's a Failure, GCing the deferred
- # would then log a critical error about unhandled Failures.
+ # result in the deferred, since `awaiting` a deferred destroys its result.
+ # (Also, if it's a Failure, GCing the deferred would log a critical error
+ # about unhandled Failures)
def got_result(r):
self._result = r
@@ -92,13 +97,15 @@ class CachedCall(Generic[TV]):
# and any eventual exception may not be reported.
# we can now await the deferred, and once it completes, return the result.
- await make_deferred_yieldable(self._deferred)
+ if isinstance(self._result, _Sentinel):
+ await make_deferred_yieldable(self._deferred)
+ assert not isinstance(self._result, _Sentinel)
+
+ if isinstance(self._result, Failure):
+ self._result.raiseException()
+ raise AssertionError("unexpected return from Failure.raiseException")
- # I *think* this is the easiest way to correctly raise a Failure without having
- # to gut-wrench into the implementation of Deferred.
- d = Deferred()
- d.callback(self._result)
- return await d
+ return self._result
class RetryOnExceptionCachedCall(Generic[TV]):
|