summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py52
1 files changed, 38 insertions, 14 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 362944bc51..277854ccbc 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2015 OpenMarket Ltd
+# Copyright 2015, 2016 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,6 +17,10 @@ import logging
 from synapse.util.async import ObservableDeferred
 from synapse.util import unwrapFirstError
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.treecache import TreeCache
+from synapse.util.logcontext import (
+    PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
+)
 
 from . import caches_by_name, DEBUG_CACHES, cache_counter
 
@@ -36,9 +40,12 @@ _CacheSentinel = object()
 
 class Cache(object):
 
-    def __init__(self, name, max_entries=1000, keylen=1, lru=True):
+    def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
         if lru:
-            self.cache = LruCache(max_size=max_entries)
+            cache_type = TreeCache if tree else dict
+            self.cache = LruCache(
+                max_size=max_entries, keylen=keylen, cache_type=cache_type
+            )
             self.max_entries = None
         else:
             self.cache = OrderedDict()
@@ -99,6 +106,15 @@ class Cache(object):
         self.sequence += 1
         self.cache.pop(key, None)
 
+    def invalidate_many(self, key):
+        self.check_thread()
+        if not isinstance(key, tuple):
+            raise TypeError(
+                "The cache key must be a tuple not %r" % (type(key),)
+            )
+        self.sequence += 1
+        self.cache.del_multi(key)
+
     def invalidate_all(self):
         self.check_thread()
         self.sequence += 1
@@ -122,7 +138,7 @@ class CacheDescriptor(object):
     which can be used to insert values into the cache specifically, without
     calling the calculation function.
     """
-    def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
+    def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
                  inlineCallbacks=False):
         self.orig = orig
 
@@ -134,8 +150,9 @@ class CacheDescriptor(object):
         self.max_entries = max_entries
         self.num_args = num_args
         self.lru = lru
+        self.tree = tree
 
-        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+        self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
 
         if len(self.arg_names) < self.num_args:
             raise Exception(
@@ -149,6 +166,7 @@ class CacheDescriptor(object):
             max_entries=self.max_entries,
             keylen=self.num_args,
             lru=self.lru,
+            tree=self.tree,
         )
 
     def __get__(self, obj, objtype=None):
@@ -175,7 +193,7 @@ class CacheDescriptor(object):
                         defer.returnValue(cached_result)
                     observer.addCallback(check_result)
 
-                return observer
+                return preserve_context_over_deferred(observer)
             except KeyError:
                 # Get the sequence number of the cache before reading from the
                 # database so that we can tell if the cache is invalidated
@@ -183,6 +201,7 @@ class CacheDescriptor(object):
                 sequence = self.cache.sequence
 
                 ret = defer.maybeDeferred(
+                    preserve_context_over_fn,
                     self.function_to_call,
                     obj, *args, **kwargs
                 )
@@ -196,10 +215,11 @@ class CacheDescriptor(object):
                 ret = ObservableDeferred(ret, consumeErrors=True)
                 self.cache.update(sequence, cache_key, ret)
 
-                return ret.observe()
+                return preserve_context_over_deferred(ret.observe())
 
         wrapped.invalidate = self.cache.invalidate
         wrapped.invalidate_all = self.cache.invalidate_all
+        wrapped.invalidate_many = self.cache.invalidate_many
         wrapped.prefill = self.cache.prefill
 
         obj.__dict__[self.orig.__name__] = wrapped
@@ -234,7 +254,7 @@ class CacheListDescriptor(object):
         self.num_args = num_args
         self.list_name = list_name
 
-        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+        self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
         self.list_pos = self.arg_names.index(self.list_name)
 
         self.cache = cache
@@ -283,6 +303,7 @@ class CacheListDescriptor(object):
                 args_to_call[self.list_name] = missing
 
                 ret_d = defer.maybeDeferred(
+                    preserve_context_over_fn,
                     self.function_to_call,
                     **args_to_call
                 )
@@ -292,7 +313,8 @@ class CacheListDescriptor(object):
                 # We need to create deferreds for each arg in the list so that
                 # we can insert the new deferred into the cache.
                 for arg in missing:
-                    observer = ret_d.observe()
+                    with PreserveLoggingContext():
+                        observer = ret_d.observe()
                     observer.addCallback(lambda r, arg: r.get(arg, None), arg)
 
                     observer = ObservableDeferred(observer)
@@ -311,31 +333,33 @@ class CacheListDescriptor(object):
 
                     cached[arg] = res
 
-            return defer.gatherResults(
+            return preserve_context_over_deferred(defer.gatherResults(
                 cached.values(),
                 consumeErrors=True,
-            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
+            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
 
         obj.__dict__[self.orig.__name__] = wrapped
 
         return wrapped
 
 
-def cached(max_entries=1000, num_args=1, lru=True):
+def cached(max_entries=1000, num_args=1, lru=True, tree=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
-        lru=lru
+        lru=lru,
+        tree=tree,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
         lru=lru,
+        tree=tree,
         inlineCallbacks=True,
     )