From 46d0937447479761a22a8c843f6ba51bbcdc914b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 2 Nov 2021 00:17:35 +0000 Subject: ObservableDeferred: run observers in order (#11229) --- changelog.d/11229.misc | 1 + synapse/util/async_helpers.py | 34 +++--- tests/util/caches/test_deferred_cache.py | 4 +- tests/util/test_async_helpers.py | 173 +++++++++++++++++++++++++++++++ tests/util/test_async_utils.py | 106 ------------------- 5 files changed, 193 insertions(+), 125 deletions(-) create mode 100644 changelog.d/11229.misc create mode 100644 tests/util/test_async_helpers.py delete mode 100644 tests/util/test_async_utils.py diff --git a/changelog.d/11229.misc b/changelog.d/11229.misc new file mode 100644 index 0000000000..7bb01cf079 --- /dev/null +++ b/changelog.d/11229.misc @@ -0,0 +1 @@ +`ObservableDeferred`: run registered observers in order. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 5df80ea8e7..96efc5f3e3 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -22,11 +22,11 @@ from typing import ( Any, Awaitable, Callable, + Collection, Dict, Generic, Hashable, Iterable, - List, Optional, Set, TypeVar, @@ -76,12 +76,17 @@ class ObservableDeferred(Generic[_T]): def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False): object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_result", None) - object.__setattr__(self, "_observers", set()) + object.__setattr__(self, "_observers", []) def callback(r): object.__setattr__(self, "_result", (True, r)) - while self._observers: - observer = self._observers.pop() + + # once we have set _result, no more entries will be added to _observers, + # so it's safe to replace it with the empty tuple. + observers = self._observers + object.__setattr__(self, "_observers", ()) + + for observer in observers: try: observer.callback(r) except Exception as e: @@ -95,12 +100,16 @@ class ObservableDeferred(Generic[_T]): def errback(f): object.__setattr__(self, "_result", (False, f)) - while self._observers: + + # once we have set _result, no more entries will be added to _observers, + # so it's safe to replace it with the empty tuple. + observers = self._observers + object.__setattr__(self, "_observers", ()) + + for observer in observers: # This is a little bit of magic to correctly propagate stack # traces when we `await` on one of the observer deferreds. f.value.__failure__ = f - - observer = self._observers.pop() try: observer.errback(f) except Exception as e: @@ -127,20 +136,13 @@ class ObservableDeferred(Generic[_T]): """ if not self._result: d: "defer.Deferred[_T]" = defer.Deferred() - - def remove(r): - self._observers.discard(d) - return r - - d.addBoth(remove) - - self._observers.add(d) + self._observers.append(d) return d else: success, res = self._result return defer.succeed(res) if success else defer.fail(res) - def observers(self) -> "List[defer.Deferred[_T]]": + def observers(self) -> "Collection[defer.Deferred[_T]]": return self._observers def has_called(self) -> bool: diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index 54a88a8325..c613ce3f10 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -47,9 +47,7 @@ class DeferredCacheTestCase(TestCase): self.assertTrue(set_d.called) return r - # TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8. - # maybe we should fix that? - # get_d.addCallback(check1) + get_d.addCallback(check1) # now fire off all the deferreds origin_d.callback(99) diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py new file mode 100644 index 0000000000..ab89cab812 --- /dev/null +++ b/tests/util/test_async_helpers.py @@ -0,0 +1,173 @@ +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer +from twisted.internet.defer import CancelledError, Deferred +from twisted.internet.task import Clock + +from synapse.logging.context import ( + SENTINEL_CONTEXT, + LoggingContext, + PreserveLoggingContext, + current_context, +) +from synapse.util.async_helpers import ObservableDeferred, timeout_deferred + +from tests.unittest import TestCase + + +class ObservableDeferredTest(TestCase): + def test_succeed(self): + origin_d = Deferred() + observable = ObservableDeferred(origin_d) + + observer1 = observable.observe() + observer2 = observable.observe() + + self.assertFalse(observer1.called) + self.assertFalse(observer2.called) + + # check the first observer is called first + def check_called_first(res): + self.assertFalse(observer2.called) + return res + + observer1.addBoth(check_called_first) + + # store the results + results = [None, None] + + def check_val(res, idx): + results[idx] = res + return res + + observer1.addCallback(check_val, 0) + observer2.addCallback(check_val, 1) + + origin_d.callback(123) + self.assertEqual(results[0], 123, "observer 1 callback result") + self.assertEqual(results[1], 123, "observer 2 callback result") + + def test_failure(self): + origin_d = Deferred() + observable = ObservableDeferred(origin_d, consumeErrors=True) + + observer1 = observable.observe() + observer2 = observable.observe() + + self.assertFalse(observer1.called) + self.assertFalse(observer2.called) + + # check the first observer is called first + def check_called_first(res): + self.assertFalse(observer2.called) + return res + + observer1.addBoth(check_called_first) + + # store the results + results = [None, None] + + def check_val(res, idx): + results[idx] = res + return None + + observer1.addErrback(check_val, 0) + observer2.addErrback(check_val, 1) + + try: + raise Exception("gah!") + except Exception as e: + origin_d.errback(e) + self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result") + self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result") + + +class TimeoutDeferredTest(TestCase): + def setUp(self): + self.clock = Clock() + + def test_times_out(self): + """Basic test case that checks that the original deferred is cancelled and that + the timing-out deferred is errbacked + """ + cancelled = [False] + + def canceller(_d): + cancelled[0] = True + + non_completing_d = Deferred(canceller) + timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) + + self.assertNoResult(timing_out_d) + self.assertFalse(cancelled[0], "deferred was cancelled prematurely") + + self.clock.pump((1.0,)) + + self.assertTrue(cancelled[0], "deferred was not cancelled by timeout") + self.failureResultOf(timing_out_d, defer.TimeoutError) + + def test_times_out_when_canceller_throws(self): + """Test that we have successfully worked around + https://twistedmatrix.com/trac/ticket/9534""" + + def canceller(_d): + raise Exception("can't cancel this deferred") + + non_completing_d = Deferred(canceller) + timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) + + self.assertNoResult(timing_out_d) + + self.clock.pump((1.0,)) + + self.failureResultOf(timing_out_d, defer.TimeoutError) + + def test_logcontext_is_preserved_on_cancellation(self): + blocking_was_cancelled = [False] + + @defer.inlineCallbacks + def blocking(): + non_completing_d = Deferred() + with PreserveLoggingContext(): + try: + yield non_completing_d + except CancelledError: + blocking_was_cancelled[0] = True + raise + + with LoggingContext("one") as context_one: + # the errbacks should be run in the test logcontext + def errback(res, deferred_name): + self.assertIs( + current_context(), + context_one, + "errback %s run in unexpected logcontext %s" + % (deferred_name, current_context()), + ) + return res + + original_deferred = blocking() + original_deferred.addErrback(errback, "orig") + timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock) + self.assertNoResult(timing_out_d) + self.assertIs(current_context(), SENTINEL_CONTEXT) + timing_out_d.addErrback(errback, "timingout") + + self.clock.pump((1.0,)) + + self.assertTrue( + blocking_was_cancelled[0], "non-completing deferred was not cancelled" + ) + self.failureResultOf(timing_out_d, defer.TimeoutError) + self.assertIs(current_context(), context_one) diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py deleted file mode 100644 index 069f875962..0000000000 --- a/tests/util/test_async_utils.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from twisted.internet import defer -from twisted.internet.defer import CancelledError, Deferred -from twisted.internet.task import Clock - -from synapse.logging.context import ( - SENTINEL_CONTEXT, - LoggingContext, - PreserveLoggingContext, - current_context, -) -from synapse.util.async_helpers import timeout_deferred - -from tests.unittest import TestCase - - -class TimeoutDeferredTest(TestCase): - def setUp(self): - self.clock = Clock() - - def test_times_out(self): - """Basic test case that checks that the original deferred is cancelled and that - the timing-out deferred is errbacked - """ - cancelled = [False] - - def canceller(_d): - cancelled[0] = True - - non_completing_d = Deferred(canceller) - timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) - - self.assertNoResult(timing_out_d) - self.assertFalse(cancelled[0], "deferred was cancelled prematurely") - - self.clock.pump((1.0,)) - - self.assertTrue(cancelled[0], "deferred was not cancelled by timeout") - self.failureResultOf(timing_out_d, defer.TimeoutError) - - def test_times_out_when_canceller_throws(self): - """Test that we have successfully worked around - https://twistedmatrix.com/trac/ticket/9534""" - - def canceller(_d): - raise Exception("can't cancel this deferred") - - non_completing_d = Deferred(canceller) - timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) - - self.assertNoResult(timing_out_d) - - self.clock.pump((1.0,)) - - self.failureResultOf(timing_out_d, defer.TimeoutError) - - def test_logcontext_is_preserved_on_cancellation(self): - blocking_was_cancelled = [False] - - @defer.inlineCallbacks - def blocking(): - non_completing_d = Deferred() - with PreserveLoggingContext(): - try: - yield non_completing_d - except CancelledError: - blocking_was_cancelled[0] = True - raise - - with LoggingContext("one") as context_one: - # the errbacks should be run in the test logcontext - def errback(res, deferred_name): - self.assertIs( - current_context(), - context_one, - "errback %s run in unexpected logcontext %s" - % (deferred_name, current_context()), - ) - return res - - original_deferred = blocking() - original_deferred.addErrback(errback, "orig") - timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock) - self.assertNoResult(timing_out_d) - self.assertIs(current_context(), SENTINEL_CONTEXT) - timing_out_d.addErrback(errback, "timingout") - - self.clock.pump((1.0,)) - - self.assertTrue( - blocking_was_cancelled[0], "non-completing deferred was not cancelled" - ) - self.failureResultOf(timing_out_d, defer.TimeoutError) - self.assertIs(current_context(), context_one) -- cgit 1.4.1