summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/util/caches/test_deferred_cache.py4
-rw-r--r--tests/util/test_async_helpers.py (renamed from tests/util/test_async_utils.py)69
2 files changed, 69 insertions, 4 deletions
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index 54a88a8325..c613ce3f10 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -47,9 +47,7 @@ class DeferredCacheTestCase(TestCase):
             self.assertTrue(set_d.called)
             return r
 
-        # TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8.
-        #   maybe we should fix that?
-        # get_d.addCallback(check1)
+        get_d.addCallback(check1)
 
         # now fire off all the deferreds
         origin_d.callback(99)
diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_helpers.py
index 069f875962..ab89cab812 100644
--- a/tests/util/test_async_utils.py
+++ b/tests/util/test_async_helpers.py
@@ -21,11 +21,78 @@ from synapse.logging.context import (
     PreserveLoggingContext,
     current_context,
 )
-from synapse.util.async_helpers import timeout_deferred
+from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
 
 from tests.unittest import TestCase
 
 
+class ObservableDeferredTest(TestCase):
+    def test_succeed(self):
+        origin_d = Deferred()
+        observable = ObservableDeferred(origin_d)
+
+        observer1 = observable.observe()
+        observer2 = observable.observe()
+
+        self.assertFalse(observer1.called)
+        self.assertFalse(observer2.called)
+
+        # check the first observer is called first
+        def check_called_first(res):
+            self.assertFalse(observer2.called)
+            return res
+
+        observer1.addBoth(check_called_first)
+
+        # store the results
+        results = [None, None]
+
+        def check_val(res, idx):
+            results[idx] = res
+            return res
+
+        observer1.addCallback(check_val, 0)
+        observer2.addCallback(check_val, 1)
+
+        origin_d.callback(123)
+        self.assertEqual(results[0], 123, "observer 1 callback result")
+        self.assertEqual(results[1], 123, "observer 2 callback result")
+
+    def test_failure(self):
+        origin_d = Deferred()
+        observable = ObservableDeferred(origin_d, consumeErrors=True)
+
+        observer1 = observable.observe()
+        observer2 = observable.observe()
+
+        self.assertFalse(observer1.called)
+        self.assertFalse(observer2.called)
+
+        # check the first observer is called first
+        def check_called_first(res):
+            self.assertFalse(observer2.called)
+            return res
+
+        observer1.addBoth(check_called_first)
+
+        # store the results
+        results = [None, None]
+
+        def check_val(res, idx):
+            results[idx] = res
+            return None
+
+        observer1.addErrback(check_val, 0)
+        observer2.addErrback(check_val, 1)
+
+        try:
+            raise Exception("gah!")
+        except Exception as e:
+            origin_d.errback(e)
+        self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
+        self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
+
+
 class TimeoutDeferredTest(TestCase):
     def setUp(self):
         self.clock = Clock()