diff --git a/changelog.d/11229.misc b/changelog.d/11229.misc
new file mode 100644
index 0000000000..7bb01cf079
--- /dev/null
+++ b/changelog.d/11229.misc
@@ -0,0 +1 @@
+`ObservableDeferred`: run registered observers in order.
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 5df80ea8e7..96efc5f3e3 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -22,11 +22,11 @@ from typing import (
Any,
Awaitable,
Callable,
+ Collection,
Dict,
Generic,
Hashable,
Iterable,
- List,
Optional,
Set,
TypeVar,
@@ -76,12 +76,17 @@ class ObservableDeferred(Generic[_T]):
def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None)
- object.__setattr__(self, "_observers", set())
+ object.__setattr__(self, "_observers", [])
def callback(r):
object.__setattr__(self, "_result", (True, r))
- while self._observers:
- observer = self._observers.pop()
+
+ # once we have set _result, no more entries will be added to _observers,
+ # so it's safe to replace it with the empty tuple.
+ observers = self._observers
+ object.__setattr__(self, "_observers", ())
+
+ for observer in observers:
try:
observer.callback(r)
except Exception as e:
@@ -95,12 +100,16 @@ class ObservableDeferred(Generic[_T]):
def errback(f):
object.__setattr__(self, "_result", (False, f))
- while self._observers:
+
+ # once we have set _result, no more entries will be added to _observers,
+ # so it's safe to replace it with the empty tuple.
+ observers = self._observers
+ object.__setattr__(self, "_observers", ())
+
+ for observer in observers:
# This is a little bit of magic to correctly propagate stack
# traces when we `await` on one of the observer deferreds.
f.value.__failure__ = f
-
- observer = self._observers.pop()
try:
observer.errback(f)
except Exception as e:
@@ -127,20 +136,13 @@ class ObservableDeferred(Generic[_T]):
"""
if not self._result:
d: "defer.Deferred[_T]" = defer.Deferred()
-
- def remove(r):
- self._observers.discard(d)
- return r
-
- d.addBoth(remove)
-
- self._observers.add(d)
+ self._observers.append(d)
return d
else:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)
- def observers(self) -> "List[defer.Deferred[_T]]":
+ def observers(self) -> "Collection[defer.Deferred[_T]]":
return self._observers
def has_called(self) -> bool:
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index 54a88a8325..c613ce3f10 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -47,9 +47,7 @@ class DeferredCacheTestCase(TestCase):
self.assertTrue(set_d.called)
return r
- # TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8.
- # maybe we should fix that?
- # get_d.addCallback(check1)
+ get_d.addCallback(check1)
# now fire off all the deferreds
origin_d.callback(99)
diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_helpers.py
index 069f875962..ab89cab812 100644
--- a/tests/util/test_async_utils.py
+++ b/tests/util/test_async_helpers.py
@@ -21,11 +21,78 @@ from synapse.logging.context import (
PreserveLoggingContext,
current_context,
)
-from synapse.util.async_helpers import timeout_deferred
+from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from tests.unittest import TestCase
+class ObservableDeferredTest(TestCase):
+ def test_succeed(self):
+ origin_d = Deferred()
+ observable = ObservableDeferred(origin_d)
+
+ observer1 = observable.observe()
+ observer2 = observable.observe()
+
+ self.assertFalse(observer1.called)
+ self.assertFalse(observer2.called)
+
+ # check the first observer is called first
+ def check_called_first(res):
+ self.assertFalse(observer2.called)
+ return res
+
+ observer1.addBoth(check_called_first)
+
+ # store the results
+ results = [None, None]
+
+ def check_val(res, idx):
+ results[idx] = res
+ return res
+
+ observer1.addCallback(check_val, 0)
+ observer2.addCallback(check_val, 1)
+
+ origin_d.callback(123)
+ self.assertEqual(results[0], 123, "observer 1 callback result")
+ self.assertEqual(results[1], 123, "observer 2 callback result")
+
+ def test_failure(self):
+ origin_d = Deferred()
+ observable = ObservableDeferred(origin_d, consumeErrors=True)
+
+ observer1 = observable.observe()
+ observer2 = observable.observe()
+
+ self.assertFalse(observer1.called)
+ self.assertFalse(observer2.called)
+
+ # check the first observer is called first
+ def check_called_first(res):
+ self.assertFalse(observer2.called)
+ return res
+
+ observer1.addBoth(check_called_first)
+
+ # store the results
+ results = [None, None]
+
+ def check_val(res, idx):
+ results[idx] = res
+ return None
+
+ observer1.addErrback(check_val, 0)
+ observer2.addErrback(check_val, 1)
+
+ try:
+ raise Exception("gah!")
+ except Exception as e:
+ origin_d.errback(e)
+ self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
+ self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
+
+
class TimeoutDeferredTest(TestCase):
def setUp(self):
self.clock = Clock()
|