summary refs log tree commit diff
path: root/synapse/util/caches
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches')
-rw-r--r--synapse/util/caches/__init__.py6
-rw-r--r--synapse/util/caches/cached_call.py129
-rw-r--r--synapse/util/caches/descriptors.py17
-rw-r--r--synapse/util/caches/stream_change_cache.py6
4 files changed, 145 insertions, 13 deletions
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py

index 89f0b38535..e676c2cac4 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py
@@ -149,8 +149,7 @@ KNOWN_KEYS = { def intern_string(string): - """Takes a (potentially) unicode string and interns it if it's ascii - """ + """Takes a (potentially) unicode string and interns it if it's ascii""" if string is None: return None @@ -161,8 +160,7 @@ def intern_string(string): def intern_dict(dictionary): - """Takes a dictionary and interns well known keys and their values - """ + """Takes a dictionary and interns well known keys and their values""" return { KNOWN_KEYS.get(key, key): _intern_known_values(key, value) for key, value in dictionary.items() diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py new file mode 100644
index 0000000000..3ee0f2317a --- /dev/null +++ b/synapse/util/caches/cached_call.py
@@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union + +from twisted.internet.defer import Deferred +from twisted.python.failure import Failure + +from synapse.logging.context import make_deferred_yieldable, run_in_background + +TV = TypeVar("TV") + + +class CachedCall(Generic[TV]): + """A wrapper for asynchronous calls whose results should be shared + + This is useful for wrapping asynchronous functions, where there might be multiple + callers, but we only want to call the underlying function once (and have the result + returned to all callers). + + Similar results can be achieved via a lock of some form, but that typically requires + more boilerplate (and ends up being less efficient). + + Correctly handles Synapse logcontexts (logs and resource usage for the underlying + function are logged against the logcontext which is active when get() is first + called). + + Example usage: + + _cached_val = CachedCall(_load_prop) + + async def handle_request() -> X: + # We can call this multiple times, but it will result in a single call to + # _load_prop(). + return await _cached_val.get() + + async def _load_prop() -> X: + await difficult_operation() + + + The implementation is deliberately single-shot (ie, once the call is initiated, + there is no way to ask for it to be run). This keeps the implementation and + semantics simple. If you want to make a new call, simply replace the whole + CachedCall object. + """ + + __slots__ = ["_callable", "_deferred", "_result"] + + def __init__(self, f: Callable[[], Awaitable[TV]]): + """ + Args: + f: The underlying function. Only one call to this function will be alive + at once (per instance of CachedCall) + """ + self._callable = f # type: Optional[Callable[[], Awaitable[TV]]] + self._deferred = None # type: Optional[Deferred] + self._result = None # type: Union[None, Failure, TV] + + async def get(self) -> TV: + """Kick off the call if necessary, and return the result""" + + # Fire off the callable now if this is our first time + if not self._deferred: + self._deferred = run_in_background(self._callable) + + # we will never need the callable again, so make sure it can be GCed + self._callable = None + + # once the deferred completes, store the result. We cannot simply leave the + # result in the deferred, since if it's a Failure, GCing the deferred + # would then log a critical error about unhandled Failures. + def got_result(r): + self._result = r + + self._deferred.addBoth(got_result) + + # TODO: consider cancellation semantics. Currently, if the call to get() + # is cancelled, the underlying call will continue (and any future calls + # will get the result/exception), which I think is *probably* ok, modulo + # the fact the underlying call may be logged to a cancelled logcontext, + # and any eventual exception may not be reported. + + # we can now await the deferred, and once it completes, return the result. + await make_deferred_yieldable(self._deferred) + + # I *think* this is the easiest way to correctly raise a Failure without having + # to gut-wrench into the implementation of Deferred. + d = Deferred() + d.callback(self._result) + return await d + + +class RetryOnExceptionCachedCall(Generic[TV]): + """A wrapper around CachedCall which will retry the call if an exception is thrown + + This is used in much the same way as CachedCall, but adds some extra functionality + so that if the underlying function throws an exception, then the next call to get() + will initiate another call to the underlying function. (Any calls to get() which + are already pending will raise the exception.) + """ + + slots = ["_cachedcall"] + + def __init__(self, f: Callable[[], Awaitable[TV]]): + async def _wrapper() -> TV: + try: + return await f() + except Exception: + # the call raised an exception: replace the underlying CachedCall to + # trigger another call next time get() is called + self._cachedcall = CachedCall(_wrapper) + raise + + self._cachedcall = CachedCall(_wrapper) + + async def get(self) -> TV: + return await self._cachedcall.get() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index a924140cdf..4e84379914 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py
@@ -122,7 +122,8 @@ class _LruCachedFunction(Generic[F]): def lru_cache( - max_entries: int = 1000, cache_context: bool = False, + max_entries: int = 1000, + cache_context: bool = False, ) -> Callable[[F], _LruCachedFunction[F]]: """A method decorator that applies a memoizing cache around the function. @@ -156,7 +157,9 @@ def lru_cache( def func(orig: F) -> _LruCachedFunction[F]: desc = LruCacheDescriptor( - orig, max_entries=max_entries, cache_context=cache_context, + orig, + max_entries=max_entries, + cache_context=cache_context, ) return cast(_LruCachedFunction[F], desc) @@ -170,14 +173,18 @@ class LruCacheDescriptor(_CacheDescriptorBase): sentinel = object() def __init__( - self, orig, max_entries: int = 1000, cache_context: bool = False, + self, + orig, + max_entries: int = 1000, + cache_context: bool = False, ): super().__init__(orig, num_args=None, cache_context=cache_context) self.max_entries = max_entries def __get__(self, obj, owner): cache = LruCache( - cache_name=self.orig.__name__, max_size=self.max_entries, + cache_name=self.orig.__name__, + max_size=self.max_entries, ) # type: LruCache[CacheKey, Any] get_cache_key = self.cache_key_builder @@ -212,7 +219,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): class DeferredCacheDescriptor(_CacheDescriptorBase): - """ A method decorator that applies a memoizing cache around the function. + """A method decorator that applies a memoizing cache around the function. This caches deferreds, rather than the results themselves. Deferreds that fail are removed from the cache. diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index c541bf4579..644e9e778a 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py
@@ -84,8 +84,7 @@ class StreamChangeCache: return False def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: - """Returns True if the entity may have been updated since stream_pos - """ + """Returns True if the entity may have been updated since stream_pos""" assert isinstance(stream_pos, int) if stream_pos < self._earliest_known_stream_pos: @@ -133,8 +132,7 @@ class StreamChangeCache: return result def has_any_entity_changed(self, stream_pos: int) -> bool: - """Returns if any entity has changed - """ + """Returns if any entity has changed""" assert type(stream_pos) is int if not self._cache: