summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2020-10-19 15:00:12 +0100
committerGitHub <noreply@github.com>2020-10-19 15:00:12 +0100
commit903d11c43a5df9f704e5dad4d14506a6470524fc (patch)
tree69a1e7d0bc51ba72f0f9026ca66759a0200f0c6c
parentInclude a simple message in email notifications that include encrypted conten... (diff)
downloadsynapse-903d11c43a5df9f704e5dad4d14506a6470524fc.tar.xz
Add `DeferredCache.get_immediate` method (#8568)
* Add `DeferredCache.get_immediate` method

A bunch of things that are currently calling `DeferredCache.get` are only
really interested in the result if it's completed. We can optimise and simplify
this case.

* Remove unused 'default' parameter to DeferredCache.get()

* another get_immediate instance
-rw-r--r--changelog.d/8568.misc1
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py2
-rw-r--r--synapse/storage/databases/main/pusher.py2
-rw-r--r--synapse/storage/databases/main/receipts.py11
-rw-r--r--synapse/storage/databases/main/roommember.py2
-rw-r--r--synapse/util/caches/deferred_cache.py35
-rw-r--r--tests/util/caches/test_deferred_cache.py27
7 files changed, 53 insertions, 27 deletions
diff --git a/changelog.d/8568.misc b/changelog.d/8568.misc
new file mode 100644
index 0000000000..0ed7db92d3
--- /dev/null
+++ b/changelog.d/8568.misc
@@ -0,0 +1 @@
+Add `get_immediate` method to `DeferredCache`.
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index c440f2545c..a701defcdd 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -496,6 +496,6 @@ class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
     # dedupe when we add callbacks to lru cache nodes, otherwise the number
     # of callbacks would grow.
     def __call__(self):
-        rules = self.cache.get(self.room_id, None, update_metrics=False)
+        rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
         if rules:
             rules.invalidate_all()
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index df8609b97b..7997242d90 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -303,7 +303,7 @@ class PusherStore(PusherWorkerStore):
                 lock=False,
             )
 
-            user_has_pusher = self.get_if_user_has_pusher.cache.get(
+            user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate(
                 (user_id,), None, update_metrics=False
             )
 
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 5cdf16521c..ca7917c989 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -25,7 +25,6 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
-from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -413,18 +412,10 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
         if receipt_type != "m.read":
             return
 
-        # Returns either an ObservableDeferred or the raw result
-        res = self.get_users_with_read_receipts_in_room.cache.get(
+        res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
             room_id, None, update_metrics=False
         )
 
-        # first handle the ObservableDeferred case
-        if isinstance(res, ObservableDeferred):
-            if res.has_called():
-                res = res.get_result()
-            else:
-                res = None
-
         if res and user_id in res:
             # We'd only be adding to the set, so no point invalidating if the
             # user is already there
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 20fcdaa529..9b08b49862 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -531,7 +531,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # If we do then we can reuse that result and simply update it with
             # any membership changes in `delta_ids`
             if context.prev_group and context.delta_ids:
-                prev_res = self._get_joined_users_from_context.cache.get(
+                prev_res = self._get_joined_users_from_context.cache.get_immediate(
                     (room_id, context.prev_group), None
                 )
                 if prev_res and isinstance(prev_res, dict):
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 4026e1f8fa..faeef75506 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -17,7 +17,16 @@
 
 import enum
 import threading
-from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast
+from typing import (
+    Callable,
+    Generic,
+    Iterable,
+    MutableMapping,
+    Optional,
+    TypeVar,
+    Union,
+    cast,
+)
 
 from prometheus_client import Gauge
 
@@ -33,7 +42,7 @@ cache_pending_metric = Gauge(
     ["name"],
 )
 
-
+T = TypeVar("T")
 KT = TypeVar("KT")
 VT = TypeVar("VT")
 
@@ -119,21 +128,21 @@ class DeferredCache(Generic[KT, VT]):
     def get(
         self,
         key: KT,
-        default=_Sentinel.sentinel,
         callback: Optional[Callable[[], None]] = None,
         update_metrics: bool = True,
-    ):
+    ) -> Union[ObservableDeferred, VT]:
         """Looks the key up in the caches.
 
         Args:
             key(tuple)
-            default: What is returned if key is not in the caches. If not
-                specified then function throws KeyError instead
             callback(fn): Gets called when the entry in the cache is invalidated
             update_metrics (bool): whether to update the cache hit rate metrics
 
         Returns:
             Either an ObservableDeferred or the result itself
+
+        Raises:
+            KeyError if the key is not found in the cache
         """
         callbacks = [callback] if callback else []
         val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
@@ -145,13 +154,19 @@ class DeferredCache(Generic[KT, VT]):
                 m.inc_hits()
             return val.deferred
 
-        val = self.cache.get(
-            key, default, callbacks=callbacks, update_metrics=update_metrics
+        val2 = self.cache.get(
+            key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
         )
-        if val is _Sentinel.sentinel:
+        if val2 is _Sentinel.sentinel:
             raise KeyError()
         else:
-            return val
+            return val2
+
+    def get_immediate(
+        self, key: KT, default: T, update_metrics: bool = True
+    ) -> Union[VT, T]:
+        """If we have a *completed* cached value, return it."""
+        return self.cache.get(key, default, update_metrics=update_metrics)
 
     def set(
         self,
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index 9717be56b6..8a08ab6661 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -38,6 +38,22 @@ class DeferredCacheTestCase(unittest.TestCase):
 
         self.assertEquals(cache.get("foo"), 123)
 
+    def test_get_immediate(self):
+        cache = DeferredCache("test")
+        d1 = defer.Deferred()
+        cache.set("key1", d1)
+
+        # get_immediate should return default
+        v = cache.get_immediate("key1", 1)
+        self.assertEqual(v, 1)
+
+        # now complete the set
+        d1.callback(2)
+
+        # get_immediate should return result
+        v = cache.get_immediate("key1", 1)
+        self.assertEqual(v, 2)
+
     def test_invalidate(self):
         cache = DeferredCache("test")
         cache.prefill(("foo",), 123)
@@ -80,9 +96,11 @@ class DeferredCacheTestCase(unittest.TestCase):
         # now do the invalidation
         cache.invalidate_all()
 
-        # lookup should return none
-        self.assertIsNone(cache.get("key1", None))
-        self.assertIsNone(cache.get("key2", None))
+        # lookup should fail
+        with self.assertRaises(KeyError):
+            cache.get("key1")
+        with self.assertRaises(KeyError):
+            cache.get("key2")
 
         # both callbacks should have been callbacked
         self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
@@ -90,7 +108,8 @@ class DeferredCacheTestCase(unittest.TestCase):
 
         # letting the other lookup complete should do nothing
         d1.callback("result1")
-        self.assertIsNone(cache.get("key1", None))
+        with self.assertRaises(KeyError):
+            cache.get("key1", None)
 
     def test_eviction(self):
         cache = DeferredCache(