summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/async_helpers.py34
1 files changed, 18 insertions, 16 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 5df80ea8e7..96efc5f3e3 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -22,11 +22,11 @@ from typing import (
     Any,
     Awaitable,
     Callable,
+    Collection,
     Dict,
     Generic,
     Hashable,
     Iterable,
-    List,
     Optional,
     Set,
     TypeVar,
@@ -76,12 +76,17 @@ class ObservableDeferred(Generic[_T]):
     def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
         object.__setattr__(self, "_deferred", deferred)
         object.__setattr__(self, "_result", None)
-        object.__setattr__(self, "_observers", set())
+        object.__setattr__(self, "_observers", [])
 
         def callback(r):
             object.__setattr__(self, "_result", (True, r))
-            while self._observers:
-                observer = self._observers.pop()
+
+            # once we have set _result, no more entries will be added to _observers,
+            # so it's safe to replace it with the empty tuple.
+            observers = self._observers
+            object.__setattr__(self, "_observers", ())
+
+            for observer in observers:
                 try:
                     observer.callback(r)
                 except Exception as e:
@@ -95,12 +100,16 @@ class ObservableDeferred(Generic[_T]):
 
         def errback(f):
             object.__setattr__(self, "_result", (False, f))
-            while self._observers:
+
+            # once we have set _result, no more entries will be added to _observers,
+            # so it's safe to replace it with the empty tuple.
+            observers = self._observers
+            object.__setattr__(self, "_observers", ())
+
+            for observer in observers:
                 # This is a little bit of magic to correctly propagate stack
                 # traces when we `await` on one of the observer deferreds.
                 f.value.__failure__ = f
-
-                observer = self._observers.pop()
                 try:
                     observer.errback(f)
                 except Exception as e:
@@ -127,20 +136,13 @@ class ObservableDeferred(Generic[_T]):
         """
         if not self._result:
             d: "defer.Deferred[_T]" = defer.Deferred()
-
-            def remove(r):
-                self._observers.discard(d)
-                return r
-
-            d.addBoth(remove)
-
-            self._observers.add(d)
+            self._observers.append(d)
             return d
         else:
             success, res = self._result
             return defer.succeed(res) if success else defer.fail(res)
 
-    def observers(self) -> "List[defer.Deferred[_T]]":
+    def observers(self) -> "Collection[defer.Deferred[_T]]":
         return self._observers
 
     def has_called(self) -> bool: