diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 88e56e3302..277854ccbc 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,6 +18,9 @@ from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
+from synapse.util.logcontext import (
+ PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
+)
from . import caches_by_name, DEBUG_CACHES, cache_counter
@@ -149,7 +152,7 @@ class CacheDescriptor(object):
self.lru = lru
self.tree = tree
- self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+ self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
if len(self.arg_names) < self.num_args:
raise Exception(
@@ -190,7 +193,7 @@ class CacheDescriptor(object):
defer.returnValue(cached_result)
observer.addCallback(check_result)
- return observer
+ return preserve_context_over_deferred(observer)
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
@@ -198,6 +201,7 @@ class CacheDescriptor(object):
sequence = self.cache.sequence
ret = defer.maybeDeferred(
+ preserve_context_over_fn,
self.function_to_call,
obj, *args, **kwargs
)
@@ -211,7 +215,7 @@ class CacheDescriptor(object):
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
- return ret.observe()
+ return preserve_context_over_deferred(ret.observe())
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
@@ -250,7 +254,7 @@ class CacheListDescriptor(object):
self.num_args = num_args
self.list_name = list_name
- self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+ self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache
@@ -299,6 +303,7 @@ class CacheListDescriptor(object):
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
+ preserve_context_over_fn,
self.function_to_call,
**args_to_call
)
@@ -308,7 +313,8 @@ class CacheListDescriptor(object):
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
- observer = ret_d.observe()
+ with PreserveLoggingContext():
+ observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
observer = ObservableDeferred(observer)
@@ -327,10 +333,10 @@ class CacheListDescriptor(object):
cached[arg] = res
- return defer.gatherResults(
+ return preserve_context_over_deferred(defer.gatherResults(
cached.values(),
consumeErrors=True,
- ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
+ ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
obj.__dict__[self.orig.__name__] = wrapped
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 494226f5ea..62cae99649 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -55,7 +55,7 @@ class ExpiringCache(object):
def f():
self._prune_cache()
- self._clock.looping_call(f, self._expiry_ms/2)
+ self._clock.looping_call(f, self._expiry_ms / 2)
def __setitem__(self, key, value):
now = self._clock.time_msec()
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
index b1e40417fd..d03678b8c8 100644
--- a/synapse/util/caches/snapshot_cache.py
+++ b/synapse/util/caches/snapshot_cache.py
@@ -87,7 +87,8 @@ class SnapshotCache(object):
# expire from the rotation of that cache.
self.next_result_cache[key] = result
self.pending_result_cache.pop(key, None)
+ return r
- result.observe().addBoth(shuffle_along)
+ result.addBoth(shuffle_along)
return result.observe()
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index c673b1bdfc..b37f1c0725 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -32,7 +32,7 @@ class StreamChangeCache(object):
entities that may have changed since that position. If position key is too
old then the cache will simply return all given entities.
"""
- def __init__(self, name, current_stream_pos, max_size=10000):
+ def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache={}):
self._max_size = max_size
self._entity_to_key = {}
self._cache = sorteddict()
@@ -40,6 +40,9 @@ class StreamChangeCache(object):
self.name = name
caches_by_name[self.name] = self._cache
+ for entity, stream_pos in prefilled_cache.items():
+ self.entity_has_changed(entity, stream_pos)
+
def has_entity_changed(self, entity, stream_pos):
"""Returns True if the entity may have been updated since stream_pos
"""
@@ -49,15 +52,10 @@ class StreamChangeCache(object):
cache_counter.inc_misses(self.name)
return True
- if stream_pos == self._earliest_known_stream_pos:
- # If the same as the earliest key, assume nothing has changed.
- cache_counter.inc_hits(self.name)
- return False
-
latest_entity_change_pos = self._entity_to_key.get(entity, None)
if latest_entity_change_pos is None:
- cache_counter.inc_misses(self.name)
- return True
+ cache_counter.inc_hits(self.name)
+ return False
if stream_pos < latest_entity_change_pos:
cache_counter.inc_misses(self.name)
@@ -95,7 +93,7 @@ class StreamChangeCache(object):
if stream_pos > self._earliest_known_stream_pos:
old_pos = self._entity_to_key.get(entity, None)
- if old_pos:
+ if old_pos is not None:
stream_pos = max(stream_pos, old_pos)
self._cache.pop(old_pos, None)
self._cache[stream_pos] = entity
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 29d02f7e95..03bc1401b7 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -58,7 +58,7 @@ class TreeCache(object):
if n:
break
- node_and_keys[i+1][0].pop(k)
+ node_and_keys[i + 1][0].pop(k)
popped, cnt = _strip_and_count_entires(popped)
self.size -= cnt
|