diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index 54a88a8325..c613ce3f10 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -47,9 +47,7 @@ class DeferredCacheTestCase(TestCase):
self.assertTrue(set_d.called)
return r
- # TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8.
- # maybe we should fix that?
- # get_d.addCallback(check1)
+ get_d.addCallback(check1)
# now fire off all the deferreds
origin_d.callback(99)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 39947a166b..ced3efd93f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -17,6 +17,7 @@ from typing import Set
from unittest import mock
from twisted.internet import defer, reactor
+from twisted.internet.defer import Deferred
from synapse.api.errors import SynapseError
from synapse.logging.context import (
@@ -703,6 +704,48 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj.mock.assert_called_once_with((40,), 2)
self.assertEqual(r, {10: "fish", 40: "gravy"})
+ def test_concurrent_lookups(self):
+ """All concurrent lookups should get the same result"""
+
+ class Cls:
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @descriptors.cached()
+ def fn(self, arg1):
+ pass
+
+ @descriptors.cachedList("fn", "args1")
+ def list_fn(self, args1) -> "Deferred[dict]":
+ return self.mock(args1)
+
+ obj = Cls()
+ deferred_result = Deferred()
+ obj.mock.return_value = deferred_result
+
+ # start off several concurrent lookups of the same key
+ d1 = obj.list_fn([10])
+ d2 = obj.list_fn([10])
+ d3 = obj.list_fn([10])
+
+ # the mock should have been called exactly once
+ obj.mock.assert_called_once_with((10,))
+ obj.mock.reset_mock()
+
+ # ... and none of the calls should yet be complete
+ self.assertFalse(d1.called)
+ self.assertFalse(d2.called)
+ self.assertFalse(d3.called)
+
+ # complete the lookup. @cachedList functions need to complete with a map
+ # of input->result
+ deferred_result.callback({10: "peas"})
+
+ # ... which should give the right result to all the callers
+ self.assertEqual(self.successResultOf(d1), {10: "peas"})
+ self.assertEqual(self.successResultOf(d2), {10: "peas"})
+ self.assertEqual(self.successResultOf(d3), {10: "peas"})
+
@defer.inlineCallbacks
def test_invalidate(self):
"""Make sure that invalidation callbacks are called."""
diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_helpers.py
index 069f875962..ab89cab812 100644
--- a/tests/util/test_async_utils.py
+++ b/tests/util/test_async_helpers.py
@@ -21,11 +21,78 @@ from synapse.logging.context import (
PreserveLoggingContext,
current_context,
)
-from synapse.util.async_helpers import timeout_deferred
+from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from tests.unittest import TestCase
+class ObservableDeferredTest(TestCase):
+ def test_succeed(self):
+ origin_d = Deferred()
+ observable = ObservableDeferred(origin_d)
+
+ observer1 = observable.observe()
+ observer2 = observable.observe()
+
+ self.assertFalse(observer1.called)
+ self.assertFalse(observer2.called)
+
+ # check the first observer is called first
+ def check_called_first(res):
+ self.assertFalse(observer2.called)
+ return res
+
+ observer1.addBoth(check_called_first)
+
+ # store the results
+ results = [None, None]
+
+ def check_val(res, idx):
+ results[idx] = res
+ return res
+
+ observer1.addCallback(check_val, 0)
+ observer2.addCallback(check_val, 1)
+
+ origin_d.callback(123)
+ self.assertEqual(results[0], 123, "observer 1 callback result")
+ self.assertEqual(results[1], 123, "observer 2 callback result")
+
+ def test_failure(self):
+ origin_d = Deferred()
+ observable = ObservableDeferred(origin_d, consumeErrors=True)
+
+ observer1 = observable.observe()
+ observer2 = observable.observe()
+
+ self.assertFalse(observer1.called)
+ self.assertFalse(observer2.called)
+
+ # check the first observer is called first
+ def check_called_first(res):
+ self.assertFalse(observer2.called)
+ return res
+
+ observer1.addBoth(check_called_first)
+
+ # store the results
+ results = [None, None]
+
+ def check_val(res, idx):
+ results[idx] = res
+ return None
+
+ observer1.addErrback(check_val, 0)
+ observer2.addErrback(check_val, 1)
+
+ try:
+ raise Exception("gah!")
+ except Exception as e:
+ origin_d.errback(e)
+ self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
+ self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
+
+
class TimeoutDeferredTest(TestCase):
def setUp(self):
self.clock = Clock()
|