diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 48dcbafeef..af65bfe7b8 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -16,6 +16,7 @@ import logging
from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError, logcontext
+from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.stringutils import to_ascii
@@ -25,7 +26,6 @@ from . import register_cache
from twisted.internet import defer
from collections import namedtuple
-import os
import functools
import inspect
import threading
@@ -37,9 +37,6 @@ logger = logging.getLogger(__name__)
_CacheSentinel = object()
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-
class CacheEntry(object):
__slots__ = [
"deferred", "sequence", "callbacks", "invalidated"
@@ -404,6 +401,7 @@ class CacheDescriptor(_CacheDescriptorBase):
wrapped.invalidate_all = cache.invalidate_all
wrapped.cache = cache
+ wrapped.num_args = self.num_args
obj.__dict__[self.orig.__name__] = wrapped
@@ -451,8 +449,9 @@ class CacheListDescriptor(_CacheDescriptorBase):
)
def __get__(self, obj, objtype=None):
-
- cache = getattr(obj, self.cached_method_name).cache
+ cached_method = getattr(obj, self.cached_method_name)
+ cache = cached_method.cache
+ num_args = cached_method.num_args
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
@@ -469,12 +468,23 @@ class CacheListDescriptor(_CacheDescriptorBase):
results = {}
cached_defers = {}
missing = []
- for arg in list_args:
+
+ # If the cache takes a single arg then that is used as the key,
+ # otherwise a tuple is used.
+ if num_args == 1:
+ def cache_get(arg):
+ return cache.get(arg, callback=invalidate_callback)
+ else:
key = list(keyargs)
- key[self.list_pos] = arg
+ def cache_get(arg):
+ key[self.list_pos] = arg
+ return cache.get(tuple(key), callback=invalidate_callback)
+
+ for arg in list_args:
try:
- res = cache.get(tuple(key), callback=invalidate_callback)
+ res = cache_get(arg)
+
if not isinstance(res, ObservableDeferred):
results[arg] = res
elif not res.has_succeeded():
@@ -505,17 +515,28 @@ class CacheListDescriptor(_CacheDescriptorBase):
observer = ObservableDeferred(observer)
- key = list(keyargs)
- key[self.list_pos] = arg
- cache.set(
- tuple(key), observer,
- callback=invalidate_callback
- )
-
- def invalidate(f, key):
- cache.invalidate(key)
- return f
- observer.addErrback(invalidate, tuple(key))
+ if num_args == 1:
+ cache.set(
+ arg, observer,
+ callback=invalidate_callback
+ )
+
+ def invalidate(f, key):
+ cache.invalidate(key)
+ return f
+ observer.addErrback(invalidate, arg)
+ else:
+ key = list(keyargs)
+ key[self.list_pos] = arg
+ cache.set(
+ tuple(key), observer,
+ callback=invalidate_callback
+ )
+
+ def invalidate(f, key):
+ cache.invalidate(key)
+ return f
+ observer.addErrback(invalidate, tuple(key))
res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
|