diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 43f66ec4be..cd48262420 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,22 +13,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import functools
import inspect
import logging
import threading
-from collections import namedtuple
+from typing import Any, Tuple, Union, cast
+from weakref import WeakValueDictionary
from six import itervalues
from prometheus_client import Gauge
+from typing_extensions import Protocol
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
@@ -36,6 +38,19 @@ from . import register_cache
logger = logging.getLogger(__name__)
+CacheKey = Union[Tuple, Any]
+
+
+class _CachedFunction(Protocol):
+ invalidate = None # type: Any
+ invalidate_all = None # type: Any
+ invalidate_many = None # type: Any
+ prefill = None # type: Any
+ cache = None # type: Any
+ num_args = None # type: Any
+
+ __name__ = None # type: str
+
cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending",
@@ -65,7 +81,6 @@ class CacheEntry(object):
class Cache(object):
__slots__ = (
"cache",
- "max_entries",
"name",
"keylen",
"thread",
@@ -73,7 +88,26 @@ class Cache(object):
"_pending_deferred_cache",
)
- def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
+ def __init__(
+ self,
+ name: str,
+ max_entries: int = 1000,
+ keylen: int = 1,
+ tree: bool = False,
+ iterable: bool = False,
+ apply_cache_factor_from_config: bool = True,
+ ):
+ """
+ Args:
+ name: The name of the cache
+ max_entries: Maximum number of entries that the cache will hold
+ keylen: The length of the tuple used as the cache key
+ tree: Use a TreeCache instead of a dict as the underlying cache type
+ iterable: If True, count each item in the cached object as an entry,
+ rather than each cached object
+ apply_cache_factor_from_config: Whether cache factors specified in the
+ config file affect `max_entries`
+ """
cache_type = TreeCache if tree else dict
self._pending_deferred_cache = cache_type()
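
A hedged usage sketch of the new constructor argument, assuming a Synapse checkout with this patch applied (the cache names below are illustrative, not from the patch): a cache can now opt out of the global cache-factor scaling, and `max_entries` becomes a read-through property onto the underlying `LruCache`'s `max_size`:

```python
from synapse.util.caches.descriptors import Cache

# Default behaviour: max_entries is scaled by the configured cache factor.
scaled = Cache("some_scaled_cache", max_entries=1000)

# Opting out: the cache holds exactly 1000 entries no matter what cache
# factor the config specifies.
fixed = Cache(
    "some_fixed_cache", max_entries=1000, apply_cache_factor_from_config=False
)

# max_entries now reports the size actually in effect after any scaling.
print(fixed.max_entries)  # 1000
```
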
@@ -83,6 +120,7 @@ class Cache(object):
cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
evicted_callback=self._on_evicted,
+ apply_cache_factor_from_config=apply_cache_factor_from_config,
)
self.name = name
@@ -95,6 +133,10 @@ class Cache(object):
collect_callback=self._metrics_collection_callback,
)
+ @property
+ def max_entries(self):
+ return self.cache.max_size
+
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
@@ -245,7 +287,9 @@ class Cache(object):
class _CacheDescriptorBase(object):
- def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
+ def __init__(
+ self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
+ ):
self.orig = orig
if inlineCallbacks:
@@ -253,7 +297,7 @@ class _CacheDescriptorBase(object):
else:
self.function_to_call = orig
- arg_spec = inspect.getargspec(orig)
+ arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args
if "cache_context" in all_args:
@@ -352,13 +396,11 @@ class CacheDescriptor(_CacheDescriptorBase):
cache_context=cache_context,
)
- max_entries = int(max_entries * get_cache_factor_for(orig.__name__))
-
self.max_entries = max_entries
self.tree = tree
self.iterable = iterable
- def __get__(self, obj, objtype=None):
+ def __get__(self, obj, owner):
cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
@@ -404,7 +446,7 @@ class CacheDescriptor(_CacheDescriptorBase):
return tuple(get_cache_key_gen(args, kwargs))
@functools.wraps(self.orig)
- def wrapped(*args, **kwargs):
+ def _wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -414,7 +456,7 @@ class CacheDescriptor(_CacheDescriptorBase):
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
if self.add_cache_context:
- kwargs["cache_context"] = _CacheContext(cache, cache_key)
+ kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
try:
cached_result_d = cache.get(cache_key, callback=invalidate_callback)
@@ -422,7 +464,7 @@ class CacheDescriptor(_CacheDescriptorBase):
if isinstance(cached_result_d, ObservableDeferred):
observer = cached_result_d.observe()
else:
- observer = cached_result_d
+ observer = defer.succeed(cached_result_d)
except KeyError:
ret = defer.maybeDeferred(
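
Wrapping an already-resolved cache entry in `defer.succeed()` makes the wrapped function's return type uniform: callers always get a Deferred, whether the result was still pending or already sitting in the cache as a plain value. A sketch of the convention:

```python
from twisted.internet import defer


def get_result(cached_value):
    # Before this change a completed cache entry was returned as a bare
    # value, so callers had to handle "Deferred or plain result".
    if isinstance(cached_value, defer.Deferred):
        return cached_value
    return defer.succeed(cached_value)


d = get_result(42)
d.addCallback(print)  # prints 42 via the same code path as a cache miss
```
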
@@ -440,6 +482,8 @@ class CacheDescriptor(_CacheDescriptorBase):
return make_deferred_yieldable(observer)
+ wrapped = cast(_CachedFunction, _wrapped)
+
if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
@@ -464,9 +508,8 @@ class CacheListDescriptor(_CacheDescriptorBase):
Given a list of keys it looks in the cache to find any hits, then passes
the list of missing keys to the wrapped function.
- Once wrapped, the function returns either a Deferred which resolves to
- the list of results, or (if all results were cached), just the list of
- results.
+ Once wrapped, the function returns a Deferred which resolves to the list
+ of results.
"""
def __init__(
@@ -600,21 +643,45 @@ class CacheListDescriptor(_CacheDescriptorBase):
)
return make_deferred_yieldable(d)
else:
- return results
+ return defer.succeed(results)
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
-class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
- # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
- # which namedtuple does for us (i.e. two _CacheContext are the same if
- # their caches and keys match). This is important in particular to
- # dedupe when we add callbacks to lru cache nodes, otherwise the number
- # of callbacks would grow.
- def invalidate(self):
- self.cache.invalidate(self.key)
+class _CacheContext:
+ """Holds cache information from the cached function higher in the calling order.
+
+ Can be used to invalidate the higher level cache entry if something changes
+ on a lower level.
+ """
+
+ _cache_context_objects = (
+ WeakValueDictionary()
+ ) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext]
+
+ def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None
+ self._cache = cache
+ self._cache_key = cache_key
+
+ def invalidate(self): # type: () -> None
+ """Invalidates the cache entry referred to by the context."""
+ self._cache.invalidate(self._cache_key)
+
+ @classmethod
+ def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext
+ """Returns an instance constructed with the given arguments.
+
+ A new instance is only created if none already exists.
+ """
+
+ # We make sure there are no identical _CacheContext instances. This is
+ # important in particular to dedupe when we add callbacks to lru cache
+ # nodes, otherwise the number of callbacks would grow.
+ return cls._cache_context_objects.setdefault(
+ (cache, cache_key), cls(cache, cache_key)
+ )
def cached(
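
A sketch of the dedupe behaviour the `WeakValueDictionary` buys, using a hypothetical `Ctx` class that mirrors `_CacheContext`: repeated `get_instance` calls for the same key hand back the identical object, and entries disappear once nothing else references them:

```python
import gc
from weakref import WeakValueDictionary


class Ctx:
    _instances = WeakValueDictionary()  # type: WeakValueDictionary

    def __init__(self, key):
        self.key = key

    @classmethod
    def get_instance(cls, key):
        return cls._instances.setdefault(key, cls(key))


a = Ctx.get_instance(("room1",))
b = Ctx.get_instance(("room1",))
assert a is b  # deduped: registering it as an LRU callback twice is a no-op

del a, b
gc.collect()
assert len(Ctx._instances) == 0  # weak values: dropped once unreferenced
```

Because the dictionary holds only weak references, deduplicating contexts this way does not pin cache entries in memory.
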
|