summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- changelog.d/10157.misc                   |  1
-rw-r--r-- synapse/replication/http/_base.py        |  6
-rw-r--r-- synapse/replication/http/membership.py   |  2
-rw-r--r-- synapse/util/caches/response_cache.py    | 99
-rw-r--r-- tests/util/caches/test_response_cache.py (renamed from tests/util/caches/test_responsecache.py) | 75
5 files changed, 146 insertions, 37 deletions
diff --git a/changelog.d/10157.misc b/changelog.d/10157.misc
new file mode 100644
index 0000000000..6c1d0e6e59
--- /dev/null
+++ b/changelog.d/10157.misc
@@ -0,0 +1 @@
+Extend `ResponseCache` to pass a context object into the callback.
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 2a13026e9a..f13a7c23b4 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -285,7 +285,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
             self.__class__.__name__,
         )
 
-    def _check_auth_and_handle(self, request, **kwargs):
+    async def _check_auth_and_handle(self, request, **kwargs):
         """Called on new incoming requests when caching is enabled. Checks
         if there is a cached response for the request and returns that,
         otherwise calls `_handle_request` and caches its response.
@@ -300,8 +300,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         if self.CACHE:
             txn_id = kwargs.pop("txn_id")
 
-            return self.response_cache.wrap(
+            return await self.response_cache.wrap(
                 txn_id, self._handle_request, request, **kwargs
             )
 
-        return self._handle_request(request, **kwargs)
+        return await self._handle_request(request, **kwargs)
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 043c25f63d..34206c5060 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -345,7 +345,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
 
         return {}
 
-    def _handle_request(  # type: ignore
+    async def _handle_request(  # type: ignore
         self, request: Request, room_id: str, user_id: str, change: str
     ) -> Tuple[int, JsonDict]:
         logger.info("user membership change: %s in %s", user_id, room_id)
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 25ea1bcc91..34c662c4db 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, Dict, Generic, Optional, TypeVar
+from typing import Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar
+
+import attr
 
 from twisted.internet import defer
 
@@ -23,10 +25,36 @@ from synapse.util.caches import register_cache
 
 logger = logging.getLogger(__name__)
 
-T = TypeVar("T")
+# the type of the key in the cache
+KV = TypeVar("KV")
+
+# the type of the result from the operation
+RV = TypeVar("RV")
+
 
+@attr.s(auto_attribs=True)
+class ResponseCacheContext(Generic[KV]):
+    """Information about a ResponseCache miss
 
-class ResponseCache(Generic[T]):
+    This object can be passed into the callback for additional feedback
+    """
+
+    cache_key: KV
+    """The cache key that caused the cache miss
+
+    This should be considered read-only.
+
+    TODO: in attrs 20.1, make it frozen with an on_setattr.
+    """
+
+    should_cache: bool = True
+    """Whether the result should be cached once the request completes.
+
+    This can be modified by the callback if it decides its result should not be cached.
+    """
+
+
+class ResponseCache(Generic[KV]):
     """
     This caches a deferred response. Until the deferred completes it will be
     returned from the cache. This means that if the client retries the request
@@ -35,8 +63,10 @@ class ResponseCache(Generic[T]):
     """
 
     def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
-        # Requests that haven't finished yet.
-        self.pending_result_cache = {}  # type: Dict[T, ObservableDeferred]
+        # This is poorly-named: it includes both complete and incomplete results.
+        # We keep complete results rather than switching to absolute values because
+        # that makes it easier to cache Failure results.
+        self.pending_result_cache = {}  # type: Dict[KV, ObservableDeferred]
 
         self.clock = clock
         self.timeout_sec = timeout_ms / 1000.0
@@ -50,16 +80,13 @@ class ResponseCache(Generic[T]):
     def __len__(self) -> int:
         return self.size()
 
-    def get(self, key: T) -> Optional[defer.Deferred]:
+    def get(self, key: KV) -> Optional[defer.Deferred]:
         """Look up the given key.
 
-        Can return either a new Deferred (which also doesn't follow the synapse
-        logcontext rules), or, if the request has completed, the actual
-        result. You will probably want to make_deferred_yieldable the result.
+        Returns a new Deferred (which also doesn't follow the synapse
+        logcontext rules). You will probably want to make_deferred_yieldable the result.
 
-        If there is no entry for the key, returns None. It is worth noting that
-        this means there is no way to distinguish a completed result of None
-        from an absent cache entry.
+        If there is no entry for the key, returns None.
 
         Args:
             key: key to get/set in the cache
@@ -76,42 +103,56 @@ class ResponseCache(Generic[T]):
             self._metrics.inc_misses()
             return None
 
-    def set(self, key: T, deferred: defer.Deferred) -> defer.Deferred:
+    def _set(
+        self, context: ResponseCacheContext[KV], deferred: defer.Deferred
+    ) -> defer.Deferred:
         """Set the entry for the given key to the given deferred.
 
         *deferred* should run its callbacks in the sentinel logcontext (ie,
         you should wrap normal synapse deferreds with
         synapse.logging.context.run_in_background).
 
-        Can return either a new Deferred (which also doesn't follow the synapse
-        logcontext rules), or, if *deferred* was already complete, the actual
-        result. You will probably want to make_deferred_yieldable the result.
+        Returns a new Deferred (which also doesn't follow the synapse logcontext rules).
+        You will probably want to make_deferred_yieldable the result.
 
         Args:
-            key: key to get/set in the cache
+            context: Information about the cache miss
             deferred: The deferred which resolves to the result.
 
         Returns:
             A new deferred which resolves to the actual result.
         """
         result = ObservableDeferred(deferred, consumeErrors=True)
+        key = context.cache_key
         self.pending_result_cache[key] = result
 
-        def remove(r):
-            if self.timeout_sec:
+        def on_complete(r):
+            # if this cache has a non-zero timeout, and the callback has not cleared
+            # the should_cache bit, we leave it in the cache for now and schedule
+            # its removal later.
+            if self.timeout_sec and context.should_cache:
                 self.clock.call_later(
                     self.timeout_sec, self.pending_result_cache.pop, key, None
                 )
             else:
+                # otherwise, remove the result immediately.
                 self.pending_result_cache.pop(key, None)
             return r
 
-        result.addBoth(remove)
+        # make sure we do this *after* adding the entry to pending_result_cache,
+        # in case the result is already complete (in which case flipping the order would
+        # leave us with a stuck entry in the cache).
+        result.addBoth(on_complete)
         return result.observe()
 
-    def wrap(
-        self, key: T, callback: Callable[..., Any], *args: Any, **kwargs: Any
-    ) -> defer.Deferred:
+    async def wrap(
+        self,
+        key: KV,
+        callback: Callable[..., Awaitable[RV]],
+        *args: Any,
+        cache_context: bool = False,
+        **kwargs: Any,
+    ) -> RV:
         """Wrap together a *get* and *set* call, taking care of logcontexts
 
         First looks up the key in the cache, and if it is present makes it
@@ -140,22 +181,28 @@ class ResponseCache(Generic[T]):
 
             *args: positional parameters to pass to the callback, if it is used
 
+            cache_context: if set, the callback will be given a `cache_context` kw arg,
+                which will be a ResponseCacheContext object.
+
             **kwargs: named parameters to pass to the callback, if it is used
 
         Returns:
-            Deferred which resolves to the result
+            The result of the callback (from the cache, or otherwise)
         """
         result = self.get(key)
         if not result:
             logger.debug(
                 "[%s]: no cached result for [%s], calculating new one", self._name, key
             )
+            context = ResponseCacheContext(cache_key=key)
+            if cache_context:
+                kwargs["cache_context"] = context
             d = run_in_background(callback, *args, **kwargs)
-            result = self.set(key, d)
+            result = self._set(context, d)
         elif not isinstance(result, defer.Deferred) or result.called:
             logger.info("[%s]: using completed cached result for [%s]", self._name, key)
         else:
             logger.info(
                 "[%s]: using incomplete cached result for [%s]", self._name, key
             )
-        return make_deferred_yieldable(result)
+        return await make_deferred_yieldable(result)
diff --git a/tests/util/caches/test_responsecache.py b/tests/util/caches/test_response_cache.py
index f9a187b8de..1e83ef2f33 100644
--- a/tests/util/caches/test_responsecache.py
+++ b/tests/util/caches/test_response_cache.py
@@ -11,14 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from parameterized import parameterized
 
-from synapse.util.caches.response_cache import ResponseCache
+from twisted.internet import defer
+
+from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext
 
 from tests.server import get_clock
 from tests.unittest import TestCase
 
 
-class DeferredCacheTestCase(TestCase):
+class ResponseCacheTestCase(TestCase):
     """
     A TestCase class for ResponseCache.
 
@@ -48,7 +51,9 @@ class DeferredCacheTestCase(TestCase):
 
         expected_result = "howdy"
 
-        wrap_d = cache.wrap(0, self.instant_return, expected_result)
+        wrap_d = defer.ensureDeferred(
+            cache.wrap(0, self.instant_return, expected_result)
+        )
 
         self.assertEqual(
             expected_result,
@@ -66,7 +71,9 @@ class DeferredCacheTestCase(TestCase):
 
         expected_result = "howdy"
 
-        wrap_d = cache.wrap(0, self.instant_return, expected_result)
+        wrap_d = defer.ensureDeferred(
+            cache.wrap(0, self.instant_return, expected_result)
+        )
 
         self.assertEqual(
             expected_result,
@@ -80,7 +87,9 @@ class DeferredCacheTestCase(TestCase):
 
         expected_result = "howdy"
 
-        wrap_d = cache.wrap(0, self.instant_return, expected_result)
+        wrap_d = defer.ensureDeferred(
+            cache.wrap(0, self.instant_return, expected_result)
+        )
 
         self.assertEqual(expected_result, self.successResultOf(wrap_d))
         self.assertEqual(
@@ -99,7 +108,10 @@ class DeferredCacheTestCase(TestCase):
 
         expected_result = "howdy"
 
-        wrap_d = cache.wrap(0, self.delayed_return, expected_result)
+        wrap_d = defer.ensureDeferred(
+            cache.wrap(0, self.delayed_return, expected_result)
+        )
+
         self.assertNoResult(wrap_d)
 
         # function wakes up, returns result
@@ -112,7 +124,9 @@ class DeferredCacheTestCase(TestCase):
 
         expected_result = "howdy"
 
-        wrap_d = cache.wrap(0, self.delayed_return, expected_result)
+        wrap_d = defer.ensureDeferred(
+            cache.wrap(0, self.delayed_return, expected_result)
+        )
         self.assertNoResult(wrap_d)
 
         # stop at 1 second to callback cache eviction callLater at that time, then another to set time at 2
@@ -129,3 +143,50 @@ class DeferredCacheTestCase(TestCase):
         self.reactor.pump((2,))
 
         self.assertIsNone(cache.get(0), "cache should not have the result now")
+
+    @parameterized.expand([(True,), (False,)])
+    def test_cache_context_nocache(self, should_cache: bool):
+        """If the callback clears the should_cache bit, the result should not be cached"""
+        cache = self.with_cache("medium_cache", ms=3000)
+
+        expected_result = "howdy"
+
+        call_count = 0
+
+        async def non_caching(o: str, cache_context: ResponseCacheContext[int]):
+            nonlocal call_count
+            call_count += 1
+            await self.clock.sleep(1)
+            cache_context.should_cache = should_cache
+            return o
+
+        wrap_d = defer.ensureDeferred(
+            cache.wrap(0, non_caching, expected_result, cache_context=True)
+        )
+        # there should be no result to start with
+        self.assertNoResult(wrap_d)
+
+        # a second call should also return a pending deferred
+        wrap2_d = defer.ensureDeferred(
+            cache.wrap(0, non_caching, expected_result, cache_context=True)
+        )
+        self.assertNoResult(wrap2_d)
+
+        # and there should have been exactly one call
+        self.assertEqual(call_count, 1)
+
+        # let the call complete
+        self.reactor.advance(1)
+
+        # both results should have completed
+        self.assertEqual(expected_result, self.successResultOf(wrap_d))
+        self.assertEqual(expected_result, self.successResultOf(wrap2_d))
+
+        if should_cache:
+            self.assertEqual(
+                expected_result,
+                self.successResultOf(cache.get(0)),
+                "cache should still have the result",
+            )
+        else:
+            self.assertIsNone(cache.get(0), "cache should not have the result")