summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/notifier.py127
-rw-r--r--synapse/util/__init__.py8
-rw-r--r--synapse/util/async.py16
3 files changed, 78 insertions, 73 deletions
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 078abfc56d..d6655f3f5a 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -16,7 +16,7 @@
 from twisted.internet import defer
 
 from synapse.util.logutils import log_function
-from synapse.util.async import run_on_reactor
+from synapse.util.async import run_on_reactor, ObservableDeferred
 from synapse.types import StreamToken
 import synapse.metrics
 
@@ -45,21 +45,11 @@ class _NotificationListener(object):
     The events stream handler will have yielded to the deferred, so to
     notify the handler it is sufficient to resolve the deferred.
     """
+    __slots__ = ["deferred"]
 
     def __init__(self, deferred):
         self.deferred = deferred
 
-    def notified(self):
-        return self.deferred.called
-
-    def notify(self, token):
-        """ Inform whoever is listening about the new events.
-        """
-        try:
-            self.deferred.callback(token)
-        except defer.AlreadyCalledError:
-            pass
-
 
 class _NotifierUserStream(object):
     """This represents a user connected to the event stream.
@@ -75,11 +65,12 @@ class _NotifierUserStream(object):
                  appservice=None):
         self.user = str(user)
         self.appservice = appservice
-        self.listeners = set()
         self.rooms = set(rooms)
         self.current_token = current_token
         self.last_notified_ms = time_now_ms
 
+        self.notify_deferred = ObservableDeferred(defer.Deferred())
+
     def notify(self, stream_key, stream_id, time_now_ms):
         """Notify any listeners for this user of a new event from an
         event source.
@@ -91,12 +82,10 @@ class _NotifierUserStream(object):
         self.current_token = self.current_token.copy_and_advance(
             stream_key, stream_id
         )
-        if self.listeners:
-            self.last_notified_ms = time_now_ms
-            listeners = self.listeners
-            self.listeners = set()
-            for listener in listeners:
-                listener.notify(self.current_token)
+        self.last_notified_ms = time_now_ms
+        noify_deferred = self.notify_deferred
+        self.notify_deferred = ObservableDeferred(defer.Deferred())
+        noify_deferred.callback(self.current_token)
 
     def remove(self, notifier):
         """ Remove this listener from all the indexes in the Notifier
@@ -114,6 +103,18 @@ class _NotifierUserStream(object):
                 self.appservice, set()
             ).discard(self)
 
+    def count_listeners(self):
+        return len(self.noify_deferred.observers())
+
+    def new_listener(self, token):
+        """Returns a deferred that is resolved when there is a new token
+        greater than the given token.
+        """
+        if self.current_token.is_after(token):
+            return _NotificationListener(defer.succeed(self.current_token))
+        else:
+            return _NotificationListener(self.notify_deferred.observe())
+
 
 class Notifier(object):
     """ This class is responsible for notifying any listeners when there are
@@ -158,7 +159,7 @@ class Notifier(object):
             for x in self.appservice_to_user_streams.values():
                 all_user_streams |= x
 
-            return sum(len(stream.listeners) for stream in all_user_streams)
+            return sum(stream.count_listeners() for stream in all_user_streams)
         metrics.register_callback("listeners", count_listeners)
 
         metrics.register_callback(
@@ -286,10 +287,6 @@ class Notifier(object):
         """Wait until the callback returns a non empty response or the
         timeout fires.
         """
-
-        deferred = defer.Deferred()
-        time_now_ms = self.clock.time_msec()
-
         user = str(user)
         user_stream = self.user_to_user_stream.get(user)
         if user_stream is None:
@@ -302,55 +299,44 @@ class Notifier(object):
                 rooms=rooms,
                 appservice=appservice,
                 current_token=current_token,
-                time_now_ms=time_now_ms,
+                time_now_ms=self.clock.time_msec(),
             )
             self._register_with_keys(user_stream)
+
+        result = None
+        if timeout:
+            # Will be set to a _NotificationListener that we'll be waiting on.
+            # Allows us to cancel it.
+            listener = None
+
+            def timed_out():
+                if listener:
+                    listener.deferred.cancel()
+            timer = self.clock.call_later(timeout/1000., timed_out)
+
+            prev_token = from_token
+            while not result:
+                try:
+                    current_token = user_stream.current_token
+
+                    result = yield callback(prev_token, current_token)
+                    if result:
+                        break
+
+                    # Now we wait for the _NotifierUserStream to be told there
+                    # is a new token.
+                    # We need to supply the token we supplied to callback so
+                    # that we don't miss any current_token updates.
+                    prev_token = current_token
+                    listener = user_stream.new_listener(prev_token)
+                    yield listener.deferred
+                except defer.CancelledError:
+                    break
+
+            self.clock.cancel_call_later(timer, ignore_errs=True)
         else:
             current_token = user_stream.current_token
-
-        listener = [_NotificationListener(deferred)]
-
-        if timeout and not current_token.is_after(from_token):
-            user_stream.listeners.add(listener[0])
-
-        if current_token.is_after(from_token):
             result = yield callback(from_token, current_token)
-        else:
-            result = None
-
-        timer = [None]
-
-        if result:
-            user_stream.listeners.discard(listener[0])
-            defer.returnValue(result)
-            return
-
-        if timeout:
-            timed_out = [False]
-
-            def _timeout_listener():
-                timed_out[0] = True
-                timer[0] = None
-                user_stream.listeners.discard(listener[0])
-                listener[0].notify(current_token)
-
-            # We create multiple notification listeners so we have to manage
-            # canceling the timeout ourselves.
-            timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
-
-            while not result and not timed_out[0]:
-                new_token = yield deferred
-                deferred = defer.Deferred()
-                listener[0] = _NotificationListener(deferred)
-                user_stream.listeners.add(listener[0])
-                result = yield callback(current_token, new_token)
-                current_token = new_token
-
-        if timer[0] is not None:
-            try:
-                self.clock.cancel_call_later(timer[0])
-            except:
-                logger.exception("Failed to cancel notifer timer")
 
         defer.returnValue(result)
 
@@ -368,6 +354,9 @@ class Notifier(object):
 
         @defer.inlineCallbacks
         def check_for_updates(before_token, after_token):
+            if not after_token.is_after(before_token):
+                defer.returnValue(None)
+
             events = []
             end_token = from_token
             for name, source in self.event_sources.sources.items():
@@ -402,7 +391,7 @@ class Notifier(object):
         expired_streams = []
         expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
         for stream in self.user_to_user_stream.values():
-            if stream.listeners:
+            if stream.count_listeners():
                 continue
             if stream.last_notified_ms < expire_before_ts:
                 expired_streams.append(stream)
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 260714ccc2..07ff25cef3 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -91,8 +91,12 @@ class Clock(object):
         with PreserveLoggingContext():
             return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
 
-    def cancel_call_later(self, timer):
-        timer.cancel()
+    def cancel_call_later(self, timer, ignore_errs=False):
+        try:
+            timer.cancel()
+        except:
+            if not ignore_errs:
+                raise
 
     def time_bound_deferred(self, given_deferred, time_out):
         if given_deferred.called:
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 1c2044e5b4..5a1d545c96 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -38,6 +38,9 @@ class ObservableDeferred(object):
     deferred.
 
     If consumeErrors is true errors will be captured from the origin deferred.
+
+    Cancelling or otherwise resolving an observer will not affect the original
+    ObservableDeferred.
     """
 
     __slots__ = ["_deferred", "_observers", "_result"]
@@ -45,7 +48,7 @@ class ObservableDeferred(object):
     def __init__(self, deferred, consumeErrors=False):
         object.__setattr__(self, "_deferred", deferred)
         object.__setattr__(self, "_result", None)
-        object.__setattr__(self, "_observers", [])
+        object.__setattr__(self, "_observers", set())
 
         def callback(r):
             self._result = (True, r)
@@ -74,12 +77,21 @@ class ObservableDeferred(object):
     def observe(self):
         if not self._result:
             d = defer.Deferred()
-            self._observers.append(d)
+
+            def remove(r):
+                self._observers.discard(d)
+                return r
+            d.addBoth(remove)
+
+            self._observers.add(d)
             return d
         else:
             success, res = self._result
             return defer.succeed(res) if success else defer.fail(res)
 
+    def observers(self):
+        return self._observers
+
     def __getattr__(self, name):
         return getattr(self._deferred, name)