diff options
author | Erik Johnston <erik@matrix.org> | 2015-06-18 15:49:05 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2015-06-18 15:49:24 +0100 |
commit | 22049ea700173017cf2f8e88fb8848e06b82f9b3 (patch) | |
tree | 5a6375689c4009777690da704bc2baec86e5a014 | |
parent | Fix notifier leak (diff) | |
download | synapse-22049ea700173017cf2f8e88fb8848e06b82f9b3.tar.xz |
Refactor the notifier.wait_for_events code to be clearer. Add _NotifierUserStream.new_listener that accpets a token to avoid races.
-rw-r--r-- | synapse/notifier.py | 122 | ||||
-rw-r--r-- | synapse/util/__init__.py | 8 | ||||
-rw-r--r-- | synapse/util/async.py | 13 |
3 files changed, 73 insertions, 70 deletions
diff --git a/synapse/notifier.py b/synapse/notifier.py index 27c034ed51..e441561029 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.util.logutils import log_function -from synapse.util.async import run_on_reactor +from synapse.util.async import run_on_reactor, ObservableDeferred from synapse.types import StreamToken import synapse.metrics @@ -45,20 +45,16 @@ class _NotificationListener(object): The events stream handler will have yielded to the deferred, so to notify the handler it is sufficient to resolve the deferred. """ + __slots__ = ["deferred"] def __init__(self, deferred): - self.deferred = deferred + object.__setattr__(self, "deferred", deferred) - def notified(self): - return self.deferred.called + def __getattr__(self, name): + return getattr(self.deferred, name) - def notify(self, token): - """ Inform whoever is listening about the new events. - """ - try: - self.deferred.callback(token) - except defer.AlreadyCalledError: - pass + def __setattr__(self, name, value): + setattr(self.deferred, name, value) class _NotifierUserStream(object): @@ -75,11 +71,12 @@ class _NotifierUserStream(object): appservice=None): self.user = str(user) self.appservice = appservice - self.listeners = set() self.rooms = set(rooms) self.current_token = current_token self.last_notified_ms = time_now_ms + self.notify_deferred = ObservableDeferred(defer.Deferred()) + def notify(self, stream_key, stream_id, time_now_ms): """Notify any listeners for this user of a new event from an event source. @@ -91,12 +88,10 @@ class _NotifierUserStream(object): self.current_token = self.current_token.copy_and_advance( stream_key, stream_id ) - if self.listeners: - self.last_notified_ms = time_now_ms - listeners = self.listeners - self.listeners = set() - for listener in listeners: - listener.notify(self.current_token) + self.last_notified_ms = time_now_ms + noify_deferred = self.notify_deferred + self.notify_deferred = ObservableDeferred(defer.Deferred()) + noify_deferred.callback(self.current_token) def remove(self, notifier): """ Remove this listener from all the indexes in the Notifier @@ -114,6 +109,18 @@ class _NotifierUserStream(object): self.appservice, set() ).discard(self) + def count_listeners(self): + return len(self.noify_deferred.observers()) + + def new_listener(self, token): + """Returns a deferred that is resolved when there is a new token + greater than the given token. + """ + if self.current_token.is_after(token): + return _NotificationListener(defer.succeed(self.current_token)) + else: + return _NotificationListener(self.notify_deferred.observe()) + class Notifier(object): """ This class is responsible for notifying any listeners when there are @@ -158,7 +165,7 @@ class Notifier(object): for x in self.appservice_to_user_streams.values(): all_user_streams |= x - return sum(len(stream.listeners) for stream in all_user_streams) + return sum(stream.count_listeners() for stream in all_user_streams) metrics.register_callback("listeners", count_listeners) metrics.register_callback( @@ -286,10 +293,6 @@ class Notifier(object): """Wait until the callback returns a non empty response or the timeout fires. """ - - deferred = defer.Deferred() - time_now_ms = self.clock.time_msec() - user = str(user) user_stream = self.user_to_user_stream.get(user) if user_stream is None: @@ -302,54 +305,38 @@ class Notifier(object): rooms=rooms, appservice=appservice, current_token=current_token, - time_now_ms=time_now_ms, + time_now_ms=self.clock.time_msec(), ) self._register_with_keys(user_stream) - else: - current_token = user_stream.current_token result = None - if current_token.is_after(from_token): - result = yield callback(from_token, current_token) - - if result: - defer.returnValue(result) - if timeout: - timer = [None] - listeners = [] - timed_out = [False] - - def notify_listeners(): - user_stream.listeners.difference_update(listeners) - for listener in listeners: - listener.notify(current_token) - del listeners[:] - - def _timeout_listener(): - timed_out[0] = True - timer[0] = None - notify_listeners() - - # We create multiple notification listeners so we have to manage - # canceling the timeout ourselves. - timer[0] = self.clock.call_later(timeout/1000., _timeout_listener) - - while not result and not timed_out[0]: - deferred = defer.Deferred() - notify_listeners() - listeners.append(_NotificationListener(deferred)) - user_stream.listeners.update(listeners) - new_token = yield deferred - - result = yield callback(current_token, new_token) - current_token = new_token - - if timer[0] is not None: + listener = None + timer = self.clock.call_later( + timeout/1000., lambda: listener.cancel() + ) + + prev_token = from_token + while not result: try: - self.clock.cancel_call_later(timer[0]) - except: - logger.exception("Failed to cancel notifer timer") + # We need to start listening to the streams *before* doing + # the callback, as otherwise we may miss something. + current_token = user_stream.current_token + + result = yield callback(prev_token, current_token) + if result: + break + + prev_token = current_token + listener = user_stream.new_listener(prev_token) + yield listener.deferred + except defer.CancelledError: + break + + self.clock.cancel_call_later(timer, ignore_errs=True) + else: + current_token = user_stream.current_token + result = yield callback(from_token, current_token) defer.returnValue(result) @@ -367,6 +354,9 @@ class Notifier(object): @defer.inlineCallbacks def check_for_updates(before_token, after_token): + if not after_token.is_after(before_token): + defer.returnValue(None) + events = [] end_token = from_token for name, source in self.event_sources.sources.items(): @@ -401,7 +391,7 @@ class Notifier(object): expired_streams = [] expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS for stream in self.user_to_user_stream.values(): - if stream.listeners: + if stream.count_listeners(): continue if stream.last_notified_ms < expire_before_ts: expired_streams.append(stream) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 260714ccc2..07ff25cef3 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -91,8 +91,12 @@ class Clock(object): with PreserveLoggingContext(): return reactor.callLater(delay, wrapped_callback, *args, **kwargs) - def cancel_call_later(self, timer): - timer.cancel() + def cancel_call_later(self, timer, ignore_errs=False): + try: + timer.cancel() + except: + if not ignore_errs: + raise def time_bound_deferred(self, given_deferred, time_out): if given_deferred.called: diff --git a/synapse/util/async.py b/synapse/util/async.py index 1c2044e5b4..6f567bcaa6 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -45,7 +45,7 @@ class ObservableDeferred(object): def __init__(self, deferred, consumeErrors=False): object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_result", None) - object.__setattr__(self, "_observers", []) + object.__setattr__(self, "_observers", set()) def callback(r): self._result = (True, r) @@ -74,12 +74,21 @@ class ObservableDeferred(object): def observe(self): if not self._result: d = defer.Deferred() - self._observers.append(d) + + def remove(r): + self._observers.discard(d) + return r + d.addBoth(remove) + + self._observers.add(d) return d else: success, res = self._result return defer.succeed(res) if success else defer.fail(res) + def observers(self): + return self._observers + def __getattr__(self, name): return getattr(self._deferred, name) |