summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/async_helpers.py22
-rw-r--r--synapse/util/caches/response_cache.py127
2 files changed, 113 insertions, 36 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index bde99ea878..150a04b53e 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import collections
 import inspect
 import itertools
@@ -57,7 +58,26 @@ logger = logging.getLogger(__name__)
 _T = TypeVar("_T")
 
 
-class ObservableDeferred(Generic[_T]):
+class AbstractObservableDeferred(Generic[_T], metaclass=abc.ABCMeta):
+    """Abstract base class defining the consumer interface of ObservableDeferred"""
+
+    __slots__ = ()
+
+    @abc.abstractmethod
+    def observe(self) -> "defer.Deferred[_T]":
+        """Add a new observer for this ObservableDeferred
+
+        This returns a brand new deferred that is resolved when the underlying
+        deferred is resolved. Interacting with the returned deferred does not
+        effect the underlying deferred.
+
+        Note that the returned Deferred doesn't follow the Synapse logcontext rules -
+        you will probably want to `make_deferred_yieldable` it.
+        """
+        ...
+
+
+class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
     """Wraps a deferred object so that we can add observer deferreds. These
     observer deferreds do not affect the callback chain of the original
     deferred.
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 88ccf44337..a3eb5f741b 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -12,19 +12,37 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Generic,
+    Iterable,
+    Optional,
+    TypeVar,
+)
 
 import attr
 
 from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.opentracing import (
+    active_span,
+    start_active_span,
+    start_active_span_follows_from,
+)
 from synapse.util import Clock
-from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.async_helpers import AbstractObservableDeferred, ObservableDeferred
 from synapse.util.caches import register_cache
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    import opentracing
+
 # the type of the key in the cache
 KV = TypeVar("KV")
 
@@ -54,6 +72,20 @@ class ResponseCacheContext(Generic[KV]):
     """
 
 
+@attr.s(auto_attribs=True)
+class ResponseCacheEntry:
+    result: AbstractObservableDeferred
+    """The (possibly incomplete) result of the operation.
+
+    Note that we continue to store an ObservableDeferred even after the operation
+    completes (rather than switching to an immediate value), since that makes it
+    easier to cache Failure results.
+    """
+
+    opentracing_span_context: "Optional[opentracing.SpanContext]"
+    """The opentracing span which generated/is generating the result"""
+
+
 class ResponseCache(Generic[KV]):
     """
     This caches a deferred response. Until the deferred completes it will be
@@ -63,10 +95,7 @@ class ResponseCache(Generic[KV]):
     """
 
     def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
-        # This is poorly-named: it includes both complete and incomplete results.
-        # We keep complete results rather than switching to absolute values because
-        # that makes it easier to cache Failure results.
-        self.pending_result_cache: Dict[KV, ObservableDeferred] = {}
+        self._result_cache: Dict[KV, ResponseCacheEntry] = {}
 
         self.clock = clock
         self.timeout_sec = timeout_ms / 1000.0
@@ -75,56 +104,63 @@ class ResponseCache(Generic[KV]):
         self._metrics = register_cache("response_cache", name, self, resizable=False)
 
     def size(self) -> int:
-        return len(self.pending_result_cache)
+        return len(self._result_cache)
 
     def __len__(self) -> int:
         return self.size()
 
-    def get(self, key: KV) -> Optional[defer.Deferred]:
-        """Look up the given key.
+    def keys(self) -> Iterable[KV]:
+        """Get the keys currently in the result cache
 
-        Returns a new Deferred (which also doesn't follow the synapse
-        logcontext rules). You will probably want to make_deferred_yieldable the result.
+        Returns both incomplete entries, and (if the timeout on this cache is non-zero),
+        complete entries which are still in the cache.
 
-        If there is no entry for the key, returns None.
+        Note that the returned iterator is not safe in the face of concurrent execution:
+        behaviour is undefined if `wrap` is called during iteration.
+        """
+        return self._result_cache.keys()
+
+    def _get(self, key: KV) -> Optional[ResponseCacheEntry]:
+        """Look up the given key.
 
         Args:
-            key: key to get/set in the cache
+            key: key to get in the cache
 
         Returns:
-            None if there is no entry for this key; otherwise a deferred which
-            resolves to the result.
+            The entry for this key, if any; else None.
         """
-        result = self.pending_result_cache.get(key)
-        if result is not None:
+        entry = self._result_cache.get(key)
+        if entry is not None:
             self._metrics.inc_hits()
-            return result.observe()
+            return entry
         else:
             self._metrics.inc_misses()
             return None
 
     def _set(
-        self, context: ResponseCacheContext[KV], deferred: "defer.Deferred[RV]"
-    ) -> "defer.Deferred[RV]":
+        self,
+        context: ResponseCacheContext[KV],
+        deferred: "defer.Deferred[RV]",
+        opentracing_span_context: "Optional[opentracing.SpanContext]",
+    ) -> ResponseCacheEntry:
         """Set the entry for the given key to the given deferred.
 
         *deferred* should run its callbacks in the sentinel logcontext (ie,
         you should wrap normal synapse deferreds with
         synapse.logging.context.run_in_background).
 
-        Returns a new Deferred (which also doesn't follow the synapse logcontext rules).
-        You will probably want to make_deferred_yieldable the result.
-
         Args:
             context: Information about the cache miss
             deferred: The deferred which resolves to the result.
+            opentracing_span_context: An opentracing span wrapping the calculation
 
         Returns:
-            A new deferred which resolves to the actual result.
+            The cache entry object.
         """
         result = ObservableDeferred(deferred, consumeErrors=True)
         key = context.cache_key
-        self.pending_result_cache[key] = result
+        entry = ResponseCacheEntry(result, opentracing_span_context)
+        self._result_cache[key] = entry
 
         def on_complete(r: RV) -> RV:
             # if this cache has a non-zero timeout, and the callback has not cleared
@@ -132,18 +168,18 @@ class ResponseCache(Generic[KV]):
             # its removal later.
             if self.timeout_sec and context.should_cache:
                 self.clock.call_later(
-                    self.timeout_sec, self.pending_result_cache.pop, key, None
+                    self.timeout_sec, self._result_cache.pop, key, None
                 )
             else:
                 # otherwise, remove the result immediately.
-                self.pending_result_cache.pop(key, None)
+                self._result_cache.pop(key, None)
             return r
 
-        # make sure we do this *after* adding the entry to pending_result_cache,
+        # make sure we do this *after* adding the entry to result_cache,
         # in case the result is already complete (in which case flipping the order would
         # leave us with a stuck entry in the cache).
         result.addBoth(on_complete)
-        return result.observe()
+        return entry
 
     async def wrap(
         self,
@@ -189,20 +225,41 @@ class ResponseCache(Generic[KV]):
         Returns:
             The result of the callback (from the cache, or otherwise)
         """
-        result = self.get(key)
-        if not result:
+        entry = self._get(key)
+        if not entry:
             logger.debug(
                 "[%s]: no cached result for [%s], calculating new one", self._name, key
             )
             context = ResponseCacheContext(cache_key=key)
             if cache_context:
                 kwargs["cache_context"] = context
-            d = run_in_background(callback, *args, **kwargs)
-            result = self._set(context, d)
-        elif not isinstance(result, defer.Deferred) or result.called:
+
+            span_context: Optional[opentracing.SpanContext] = None
+
+            async def cb() -> RV:
+                # NB it is important that we do not `await` before setting span_context!
+                nonlocal span_context
+                with start_active_span(f"ResponseCache[{self._name}].calculate"):
+                    span = active_span()
+                    if span:
+                        span_context = span.context
+                    return await callback(*args, **kwargs)
+
+            d = run_in_background(cb)
+            entry = self._set(context, d, span_context)
+            return await make_deferred_yieldable(entry.result.observe())
+
+        result = entry.result.observe()
+        if result.called:
             logger.info("[%s]: using completed cached result for [%s]", self._name, key)
         else:
             logger.info(
                 "[%s]: using incomplete cached result for [%s]", self._name, key
             )
-        return await make_deferred_yieldable(result)
+
+        span_context = entry.opentracing_span_context
+        with start_active_span_follows_from(
+            f"ResponseCache[{self._name}].wait",
+            contexts=(span_context,) if span_context else (),
+        ):
+            return await make_deferred_yieldable(result)