summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py26
1 files changed, 19 insertions, 7 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 88e56e3302..35544b19fd 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,6 +18,9 @@ from synapse.util.async import ObservableDeferred
 from synapse.util import unwrapFirstError
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache
+from synapse.util.logcontext import (
+    PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
+)
 
 from . import caches_by_name, DEBUG_CACHES, cache_counter
 
@@ -25,6 +28,7 @@ from twisted.internet import defer
 
 from collections import OrderedDict
 
+import os
 import functools
 import inspect
 import threading
@@ -35,6 +39,9 @@ logger = logging.getLogger(__name__)
 _CacheSentinel = object()
 
 
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
+
+
 class Cache(object):
 
     def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
@@ -137,6 +144,8 @@ class CacheDescriptor(object):
     """
     def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
                  inlineCallbacks=False):
+        max_entries = int(max_entries * CACHE_SIZE_FACTOR)
+
         self.orig = orig
 
         if inlineCallbacks:
@@ -149,7 +158,7 @@ class CacheDescriptor(object):
         self.lru = lru
         self.tree = tree
 
-        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+        self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
 
         if len(self.arg_names) < self.num_args:
             raise Exception(
@@ -190,7 +199,7 @@ class CacheDescriptor(object):
                         defer.returnValue(cached_result)
                     observer.addCallback(check_result)
 
-                return observer
+                return preserve_context_over_deferred(observer)
             except KeyError:
                 # Get the sequence number of the cache before reading from the
                 # database so that we can tell if the cache is invalidated
@@ -198,6 +207,7 @@ class CacheDescriptor(object):
                 sequence = self.cache.sequence
 
                 ret = defer.maybeDeferred(
+                    preserve_context_over_fn,
                     self.function_to_call,
                     obj, *args, **kwargs
                 )
@@ -211,7 +221,7 @@ class CacheDescriptor(object):
                 ret = ObservableDeferred(ret, consumeErrors=True)
                 self.cache.update(sequence, cache_key, ret)
 
-                return ret.observe()
+                return preserve_context_over_deferred(ret.observe())
 
         wrapped.invalidate = self.cache.invalidate
         wrapped.invalidate_all = self.cache.invalidate_all
@@ -250,7 +260,7 @@ class CacheListDescriptor(object):
         self.num_args = num_args
         self.list_name = list_name
 
-        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+        self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
         self.list_pos = self.arg_names.index(self.list_name)
 
         self.cache = cache
@@ -299,6 +309,7 @@ class CacheListDescriptor(object):
                 args_to_call[self.list_name] = missing
 
                 ret_d = defer.maybeDeferred(
+                    preserve_context_over_fn,
                     self.function_to_call,
                     **args_to_call
                 )
@@ -308,7 +319,8 @@ class CacheListDescriptor(object):
                 # We need to create deferreds for each arg in the list so that
                 # we can insert the new deferred into the cache.
                 for arg in missing:
-                    observer = ret_d.observe()
+                    with PreserveLoggingContext():
+                        observer = ret_d.observe()
                     observer.addCallback(lambda r, arg: r.get(arg, None), arg)
 
                     observer = ObservableDeferred(observer)
@@ -327,10 +339,10 @@ class CacheListDescriptor(object):
 
                     cached[arg] = res
 
-            return defer.gatherResults(
+            return preserve_context_over_deferred(defer.gatherResults(
                 cached.values(),
                 consumeErrors=True,
-            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
+            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
 
         obj.__dict__[self.orig.__name__] = wrapped