diff --git a/changelog.d/8568.misc b/changelog.d/8568.misc
new file mode 100644
index 0000000000..0ed7db92d3
--- /dev/null
+++ b/changelog.d/8568.misc
@@ -0,0 +1 @@
+Add `get_immediate` method to `DeferredCache`.
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index c440f2545c..a701defcdd 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -496,6 +496,6 @@ class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
# dedupe when we add callbacks to lru cache nodes, otherwise the number
# of callbacks would grow.
def __call__(self):
- rules = self.cache.get(self.room_id, None, update_metrics=False)
+ rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
if rules:
rules.invalidate_all()
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index df8609b97b..7997242d90 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -303,7 +303,7 @@ class PusherStore(PusherWorkerStore):
lock=False,
)
- user_has_pusher = self.get_if_user_has_pusher.cache.get(
+ user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate(
(user_id,), None, update_metrics=False
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 5cdf16521c..ca7917c989 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -25,7 +25,6 @@ from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
-from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -413,18 +412,10 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
if receipt_type != "m.read":
return
- # Returns either an ObservableDeferred or the raw result
- res = self.get_users_with_read_receipts_in_room.cache.get(
+ res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
room_id, None, update_metrics=False
)
- # first handle the ObservableDeferred case
- if isinstance(res, ObservableDeferred):
- if res.has_called():
- res = res.get_result()
- else:
- res = None
-
if res and user_id in res:
# We'd only be adding to the set, so no point invalidating if the
# user is already there
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 20fcdaa529..9b08b49862 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -531,7 +531,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# If we do then we can reuse that result and simply update it with
# any membership changes in `delta_ids`
if context.prev_group and context.delta_ids:
- prev_res = self._get_joined_users_from_context.cache.get(
+ prev_res = self._get_joined_users_from_context.cache.get_immediate(
(room_id, context.prev_group), None
)
if prev_res and isinstance(prev_res, dict):
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 4026e1f8fa..faeef75506 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -17,7 +17,16 @@
import enum
import threading
-from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast
+from typing import (
+ Callable,
+ Generic,
+ Iterable,
+ MutableMapping,
+ Optional,
+ TypeVar,
+ Union,
+ cast,
+)
from prometheus_client import Gauge
@@ -33,7 +42,7 @@ cache_pending_metric = Gauge(
["name"],
)
-
+T = TypeVar("T")
KT = TypeVar("KT")
VT = TypeVar("VT")
@@ -119,21 +128,21 @@ class DeferredCache(Generic[KT, VT]):
def get(
self,
key: KT,
- default=_Sentinel.sentinel,
callback: Optional[Callable[[], None]] = None,
update_metrics: bool = True,
- ):
+ ) -> Union[ObservableDeferred, VT]:
"""Looks the key up in the caches.
Args:
key(tuple)
- default: What is returned if key is not in the caches. If not
- specified then function throws KeyError instead
callback(fn): Gets called when the entry in the cache is invalidated
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
Either an ObservableDeferred or the result itself
+
+ Raises:
+ KeyError if the key is not found in the cache
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
@@ -145,13 +154,19 @@ class DeferredCache(Generic[KT, VT]):
m.inc_hits()
return val.deferred
- val = self.cache.get(
- key, default, callbacks=callbacks, update_metrics=update_metrics
+ val2 = self.cache.get(
+ key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
)
- if val is _Sentinel.sentinel:
+ if val2 is _Sentinel.sentinel:
raise KeyError()
else:
- return val
+ return val2
+
+ def get_immediate(
+ self, key: KT, default: T, update_metrics: bool = True
+ ) -> Union[VT, T]:
+ """If we have a *completed* cached value, return it."""
+ return self.cache.get(key, default, update_metrics=update_metrics)
def set(
self,
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index 9717be56b6..8a08ab6661 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -38,6 +38,22 @@ class DeferredCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get("foo"), 123)
+ def test_get_immediate(self):
+ cache = DeferredCache("test")
+ d1 = defer.Deferred()
+ cache.set("key1", d1)
+
+ # get_immediate should return default
+ v = cache.get_immediate("key1", 1)
+ self.assertEqual(v, 1)
+
+ # now complete the set
+ d1.callback(2)
+
+ # get_immediate should return result
+ v = cache.get_immediate("key1", 1)
+ self.assertEqual(v, 2)
+
def test_invalidate(self):
cache = DeferredCache("test")
cache.prefill(("foo",), 123)
@@ -80,9 +96,11 @@ class DeferredCacheTestCase(unittest.TestCase):
# now do the invalidation
cache.invalidate_all()
- # lookup should return none
- self.assertIsNone(cache.get("key1", None))
- self.assertIsNone(cache.get("key2", None))
+ # lookup should fail
+ with self.assertRaises(KeyError):
+ cache.get("key1")
+ with self.assertRaises(KeyError):
+ cache.get("key2")
# both callbacks should have been callbacked
self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
@@ -90,7 +108,8 @@ class DeferredCacheTestCase(unittest.TestCase):
# letting the other lookup complete should do nothing
d1.callback("result1")
- self.assertIsNone(cache.get("key1", None))
+ with self.assertRaises(KeyError):
+ cache.get("key1", None)
def test_eviction(self):
cache = DeferredCache(
|