diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 89f0b38535..e676c2cac4 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -149,8 +149,7 @@ KNOWN_KEYS = {
def intern_string(string):
- """Takes a (potentially) unicode string and interns it if it's ascii
- """
+ """Takes a (potentially) unicode string and interns it if it's ascii"""
if string is None:
return None
@@ -161,8 +160,7 @@ def intern_string(string):
def intern_dict(dictionary):
- """Takes a dictionary and interns well known keys and their values
- """
+ """Takes a dictionary and interns well known keys and their values"""
return {
KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
for key, value in dictionary.items()
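
For context, a minimal sketch of how these interning helpers are typically used (the event contents below are hypothetical, not taken from this patch):

    from synapse.util.caches import intern_dict, intern_string

    # Interning keeps a single shared copy of frequently repeated ASCII strings
    # (event types, state keys, ...), so many cached events reference the same objects.
    event_type = intern_string("m.room.member")
    event = intern_dict({"type": "m.room.member", "state_key": "@alice:example.com"})
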
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
new file mode 100644
index 0000000000..3ee0f2317a
--- /dev/null
+++ b/synapse/util/caches/cached_call.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union
+
+from twisted.internet.defer import Deferred
+from twisted.python.failure import Failure
+
+from synapse.logging.context import make_deferred_yieldable, run_in_background
+
+TV = TypeVar("TV")
+
+
+class CachedCall(Generic[TV]):
+ """A wrapper for asynchronous calls whose results should be shared
+
+ This is useful for wrapping asynchronous functions, where there might be multiple
+ callers, but we only want to call the underlying function once (and have the result
+ returned to all callers).
+
+ Similar results can be achieved via a lock of some form, but that typically requires
+ more boilerplate (and ends up being less efficient).
+
+ Correctly handles Synapse logcontexts (logs and resource usage for the underlying
+ function are logged against the logcontext which is active when get() is first
+ called).
+
+ Example usage:
+
+ _cached_val = CachedCall(_load_prop)
+
+ async def handle_request() -> X:
+ # We can call this multiple times, but it will result in a single call to
+ # _load_prop().
+ return await _cached_val.get()
+
+ async def _load_prop() -> X:
+            return await difficult_operation()
+
+
+ The implementation is deliberately single-shot (ie, once the call is initiated,
+    there is no way to ask for it to be re-run). This keeps the implementation and
+ semantics simple. If you want to make a new call, simply replace the whole
+ CachedCall object.
+ """
+
+ __slots__ = ["_callable", "_deferred", "_result"]
+
+ def __init__(self, f: Callable[[], Awaitable[TV]]):
+ """
+ Args:
+ f: The underlying function. Only one call to this function will be alive
+ at once (per instance of CachedCall)
+ """
+ self._callable = f # type: Optional[Callable[[], Awaitable[TV]]]
+ self._deferred = None # type: Optional[Deferred]
+ self._result = None # type: Union[None, Failure, TV]
+
+ async def get(self) -> TV:
+ """Kick off the call if necessary, and return the result"""
+
+ # Fire off the callable now if this is our first time
+ if not self._deferred:
+ self._deferred = run_in_background(self._callable)
+
+ # we will never need the callable again, so make sure it can be GCed
+ self._callable = None
+
+ # once the deferred completes, store the result. We cannot simply leave the
+ # result in the deferred, since if it's a Failure, GCing the deferred
+ # would then log a critical error about unhandled Failures.
+ def got_result(r):
+ self._result = r
+
+ self._deferred.addBoth(got_result)
+
+ # TODO: consider cancellation semantics. Currently, if the call to get()
+ # is cancelled, the underlying call will continue (and any future calls
+ # will get the result/exception), which I think is *probably* ok, modulo
+        # the fact that the underlying call may be logged to a cancelled logcontext,
+ # and any eventual exception may not be reported.
+
+ # we can now await the deferred, and once it completes, return the result.
+ await make_deferred_yieldable(self._deferred)
+
+ # I *think* this is the easiest way to correctly raise a Failure without having
+ # to gut-wrench into the implementation of Deferred.
+ d = Deferred()
+ d.callback(self._result)
+ return await d
+
+
+class RetryOnExceptionCachedCall(Generic[TV]):
+ """A wrapper around CachedCall which will retry the call if an exception is thrown
+
+ This is used in much the same way as CachedCall, but adds some extra functionality
+ so that if the underlying function throws an exception, then the next call to get()
+ will initiate another call to the underlying function. (Any calls to get() which
+ are already pending will raise the exception.)
+ """
+
+    __slots__ = ["_cachedcall"]
+
+ def __init__(self, f: Callable[[], Awaitable[TV]]):
+ async def _wrapper() -> TV:
+ try:
+ return await f()
+ except Exception:
+ # the call raised an exception: replace the underlying CachedCall to
+ # trigger another call next time get() is called
+ self._cachedcall = CachedCall(_wrapper)
+ raise
+
+ self._cachedcall = CachedCall(_wrapper)
+
+ async def get(self) -> TV:
+ return await self._cachedcall.get()
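
A minimal usage sketch of the new wrappers (the KeyFetcher class and method names are illustrative, not part of this patch):

    from synapse.util.caches.cached_call import RetryOnExceptionCachedCall

    class KeyFetcher:
        def __init__(self):
            # All callers share one in-flight fetch; if it raises, the wrapper is
            # reset so the next get() starts a fresh attempt.
            self._key = RetryOnExceptionCachedCall(self._fetch_key)

        async def get_key(self):
            return await self._key.get()

        async def _fetch_key(self):
            # expensive network call, run once and shared between callers
            ...
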
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index a924140cdf..4e84379914 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -122,7 +122,8 @@ class _LruCachedFunction(Generic[F]):
def lru_cache(
- max_entries: int = 1000, cache_context: bool = False,
+ max_entries: int = 1000,
+ cache_context: bool = False,
) -> Callable[[F], _LruCachedFunction[F]]:
"""A method decorator that applies a memoizing cache around the function.
@@ -156,7 +157,9 @@ def lru_cache(
def func(orig: F) -> _LruCachedFunction[F]:
desc = LruCacheDescriptor(
- orig, max_entries=max_entries, cache_context=cache_context,
+ orig,
+ max_entries=max_entries,
+ cache_context=cache_context,
)
return cast(_LruCachedFunction[F], desc)
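
A hedged example of how this decorator is applied (the class and method names are invented for illustration):

    from synapse.util.caches.descriptors import lru_cache

    class RuleStore:
        @lru_cache(max_entries=1000)
        def get_rules_for_user(self, user_id: str):
            # Expensive synchronous computation; the result is memoized keyed on
            # user_id and evicted least-recently-used once max_entries is exceeded.
            ...
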
@@ -170,14 +173,18 @@ class LruCacheDescriptor(_CacheDescriptorBase):
sentinel = object()
def __init__(
- self, orig, max_entries: int = 1000, cache_context: bool = False,
+ self,
+ orig,
+ max_entries: int = 1000,
+ cache_context: bool = False,
):
super().__init__(orig, num_args=None, cache_context=cache_context)
self.max_entries = max_entries
def __get__(self, obj, owner):
cache = LruCache(
- cache_name=self.orig.__name__, max_size=self.max_entries,
+ cache_name=self.orig.__name__,
+ max_size=self.max_entries,
) # type: LruCache[CacheKey, Any]
get_cache_key = self.cache_key_builder
@@ -212,7 +219,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
class DeferredCacheDescriptor(_CacheDescriptorBase):
- """ A method decorator that applies a memoizing cache around the function.
+ """A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
fail are removed from the cache.
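
For comparison, the asynchronous counterpart is the `cached` decorator built on this descriptor; a sketch only, with an illustrative store method:

    from synapse.util.caches.descriptors import cached

    class DataStore:
        @cached(max_entries=10000)
        async def get_user_by_id(self, user_id: str):
            # The deferred result is cached and shared between callers; if it
            # fails, the entry is removed so a later call retries the lookup.
            ...
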
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index c541bf4579..644e9e778a 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -84,8 +84,7 @@ class StreamChangeCache:
return False
def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool:
- """Returns True if the entity may have been updated since stream_pos
- """
+ """Returns True if the entity may have been updated since stream_pos"""
assert isinstance(stream_pos, int)
if stream_pos < self._earliest_known_stream_pos:
@@ -133,8 +132,7 @@ class StreamChangeCache:
return result
def has_any_entity_changed(self, stream_pos: int) -> bool:
- """Returns if any entity has changed
- """
+ """Returns if any entity has changed"""
assert type(stream_pos) is int
if not self._cache:
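
Illustrative behaviour of the methods touched above (the constructor arguments shown are assumptions about the existing API, not changes made by this patch):

    from synapse.util.caches.stream_change_cache import StreamChangeCache

    cache = StreamChangeCache("ExampleChangeCache", 100)  # nothing known before pos 100
    cache.entity_has_changed("@alice:example.com", 105)

    cache.has_entity_changed("@alice:example.com", 102)  # True: a change at 105 > 102
    cache.has_entity_changed("@bob:example.com", 102)    # False: no change recorded since 100
    cache.has_entity_changed("@bob:example.com", 90)     # True: 90 predates the cache's knowledge
    cache.has_any_entity_changed(102)                     # True: alice changed at 105
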
|